Initial commit before Ralph loop
This commit is contained in:
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
Data processing module for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
Contains SQLite database management, data loaders, and Snowflake integration.
|
||||
Handles the migration from CSV-based storage to SQLite for improved performance.
|
||||
|
||||
Submodules:
|
||||
database: SQLite connection management and schema definitions
|
||||
loader: Data loading abstractions (CSV, SQLite, Snowflake)
|
||||
snowflake_connector: Snowflake integration with SSO authentication
|
||||
"""
|
||||
|
||||
from data_processing.database import (
|
||||
DatabaseConfig,
|
||||
DatabaseManager,
|
||||
default_db_config,
|
||||
default_db_manager,
|
||||
)
|
||||
from data_processing.schema import (
|
||||
# Reference table schemas
|
||||
REF_DRUG_NAMES_SCHEMA,
|
||||
REF_ORGANIZATIONS_SCHEMA,
|
||||
REF_DIRECTORIES_SCHEMA,
|
||||
REF_DRUG_DIRECTORY_MAP_SCHEMA,
|
||||
REF_DRUG_INDICATION_CLUSTERS_SCHEMA,
|
||||
REFERENCE_TABLES_SCHEMA,
|
||||
# Fact table schemas
|
||||
FACT_INTERVENTIONS_SCHEMA,
|
||||
FACT_TABLES_SCHEMA,
|
||||
# Materialized view schemas
|
||||
MV_PATIENT_TREATMENT_SUMMARY_SCHEMA,
|
||||
MATERIALIZED_VIEWS_SCHEMA,
|
||||
# File tracking schemas
|
||||
PROCESSED_FILES_SCHEMA,
|
||||
FILE_TRACKING_SCHEMA,
|
||||
# Combined schema
|
||||
ALL_TABLES_SCHEMA,
|
||||
# Reference table functions
|
||||
create_reference_tables,
|
||||
drop_reference_tables,
|
||||
get_reference_table_counts,
|
||||
verify_reference_tables_exist,
|
||||
# Fact table functions
|
||||
create_fact_tables,
|
||||
drop_fact_tables,
|
||||
get_fact_table_counts,
|
||||
verify_fact_tables_exist,
|
||||
# File tracking functions
|
||||
create_file_tracking_tables,
|
||||
drop_file_tracking_tables,
|
||||
get_file_tracking_counts,
|
||||
verify_file_tracking_tables_exist,
|
||||
# Combined functions
|
||||
create_all_tables,
|
||||
drop_all_tables,
|
||||
get_all_table_counts,
|
||||
verify_all_tables_exist,
|
||||
)
|
||||
|
||||
# Reference data migration functions
|
||||
from data_processing.reference_data import (
|
||||
MigrationResult,
|
||||
migrate_drug_names,
|
||||
get_drug_name_counts,
|
||||
verify_drug_names_migration,
|
||||
migrate_organizations,
|
||||
get_organization_counts,
|
||||
verify_organizations_migration,
|
||||
migrate_directories,
|
||||
get_directory_counts,
|
||||
verify_directories_migration,
|
||||
migrate_drug_directory_map,
|
||||
get_drug_directory_map_counts,
|
||||
verify_drug_directory_map_migration,
|
||||
migrate_drug_indication_clusters,
|
||||
get_drug_indication_cluster_counts,
|
||||
verify_drug_indication_clusters_migration,
|
||||
)
|
||||
|
||||
# Data loader abstractions
|
||||
from data_processing.loader import (
|
||||
DataLoader,
|
||||
FileDataLoader,
|
||||
SQLiteDataLoader,
|
||||
LoadResult,
|
||||
get_loader,
|
||||
REQUIRED_COLUMNS,
|
||||
OPTIONAL_COLUMNS,
|
||||
)
|
||||
|
||||
# Patient data migration functions
|
||||
from data_processing.patient_data import (
|
||||
PatientDataLoadResult,
|
||||
load_patient_data,
|
||||
get_patient_data_stats,
|
||||
list_processed_files,
|
||||
calculate_file_hash,
|
||||
# Materialized view functions
|
||||
MVRefreshResult,
|
||||
refresh_patient_treatment_summary,
|
||||
get_patient_summary_stats,
|
||||
verify_mv_consistency,
|
||||
)
|
||||
|
||||
# Snowflake connector
|
||||
from data_processing.snowflake_connector import (
|
||||
SnowflakeConnector,
|
||||
SnowflakeConnectionError,
|
||||
SnowflakeNotConfiguredError,
|
||||
SnowflakeNotAvailableError,
|
||||
ConnectionInfo,
|
||||
get_connector,
|
||||
reset_connector,
|
||||
is_snowflake_available,
|
||||
is_snowflake_configured,
|
||||
SNOWFLAKE_AVAILABLE,
|
||||
)
|
||||
|
||||
# Query result caching
|
||||
from data_processing.cache import (
|
||||
QueryCache,
|
||||
CacheEntry,
|
||||
CacheStats,
|
||||
get_cache,
|
||||
reset_cache,
|
||||
is_cache_enabled,
|
||||
)
|
||||
|
||||
# Data source management with fallback chain
|
||||
from data_processing.data_source import (
|
||||
DataSourceType,
|
||||
DataSourceResult,
|
||||
SourceStatus,
|
||||
DataSourceManager,
|
||||
get_data_source_manager,
|
||||
get_data,
|
||||
reset_data_source_manager,
|
||||
)
|
||||
|
||||
# Diagnosis lookup (GP diagnosis validation)
|
||||
from data_processing.diagnosis_lookup import (
|
||||
ClusterSnomedCodes,
|
||||
IndicationValidationResult,
|
||||
DrugIndicationMatchRate,
|
||||
get_drug_clusters,
|
||||
get_drug_cluster_ids,
|
||||
get_cluster_snomed_codes,
|
||||
patient_has_indication,
|
||||
validate_indication,
|
||||
get_indication_match_rate,
|
||||
batch_validate_indications,
|
||||
get_available_clusters,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Database management
|
||||
"DatabaseConfig",
|
||||
"DatabaseManager",
|
||||
"default_db_config",
|
||||
"default_db_manager",
|
||||
# Reference table schemas
|
||||
"REF_DRUG_NAMES_SCHEMA",
|
||||
"REF_ORGANIZATIONS_SCHEMA",
|
||||
"REF_DIRECTORIES_SCHEMA",
|
||||
"REF_DRUG_DIRECTORY_MAP_SCHEMA",
|
||||
"REF_DRUG_INDICATION_CLUSTERS_SCHEMA",
|
||||
"REFERENCE_TABLES_SCHEMA",
|
||||
# Fact table schemas
|
||||
"FACT_INTERVENTIONS_SCHEMA",
|
||||
"FACT_TABLES_SCHEMA",
|
||||
# Materialized view schemas
|
||||
"MV_PATIENT_TREATMENT_SUMMARY_SCHEMA",
|
||||
"MATERIALIZED_VIEWS_SCHEMA",
|
||||
# File tracking schemas
|
||||
"PROCESSED_FILES_SCHEMA",
|
||||
"FILE_TRACKING_SCHEMA",
|
||||
# Combined schema
|
||||
"ALL_TABLES_SCHEMA",
|
||||
# Reference table functions
|
||||
"create_reference_tables",
|
||||
"drop_reference_tables",
|
||||
"get_reference_table_counts",
|
||||
"verify_reference_tables_exist",
|
||||
# Fact table functions
|
||||
"create_fact_tables",
|
||||
"drop_fact_tables",
|
||||
"get_fact_table_counts",
|
||||
"verify_fact_tables_exist",
|
||||
# File tracking functions
|
||||
"create_file_tracking_tables",
|
||||
"drop_file_tracking_tables",
|
||||
"get_file_tracking_counts",
|
||||
"verify_file_tracking_tables_exist",
|
||||
# Combined functions
|
||||
"create_all_tables",
|
||||
"drop_all_tables",
|
||||
"get_all_table_counts",
|
||||
"verify_all_tables_exist",
|
||||
# Reference data migration
|
||||
"MigrationResult",
|
||||
"migrate_drug_names",
|
||||
"get_drug_name_counts",
|
||||
"verify_drug_names_migration",
|
||||
"migrate_organizations",
|
||||
"get_organization_counts",
|
||||
"verify_organizations_migration",
|
||||
"migrate_directories",
|
||||
"get_directory_counts",
|
||||
"verify_directories_migration",
|
||||
"migrate_drug_directory_map",
|
||||
"get_drug_directory_map_counts",
|
||||
"verify_drug_directory_map_migration",
|
||||
"migrate_drug_indication_clusters",
|
||||
"get_drug_indication_cluster_counts",
|
||||
"verify_drug_indication_clusters_migration",
|
||||
# Data loader abstractions
|
||||
"DataLoader",
|
||||
"FileDataLoader",
|
||||
"SQLiteDataLoader",
|
||||
"LoadResult",
|
||||
"get_loader",
|
||||
"REQUIRED_COLUMNS",
|
||||
"OPTIONAL_COLUMNS",
|
||||
# Patient data migration
|
||||
"PatientDataLoadResult",
|
||||
"load_patient_data",
|
||||
"get_patient_data_stats",
|
||||
"list_processed_files",
|
||||
"calculate_file_hash",
|
||||
# Materialized view functions
|
||||
"MVRefreshResult",
|
||||
"refresh_patient_treatment_summary",
|
||||
"get_patient_summary_stats",
|
||||
"verify_mv_consistency",
|
||||
# Snowflake connector
|
||||
"SnowflakeConnector",
|
||||
"SnowflakeConnectionError",
|
||||
"SnowflakeNotConfiguredError",
|
||||
"SnowflakeNotAvailableError",
|
||||
"ConnectionInfo",
|
||||
"get_connector",
|
||||
"reset_connector",
|
||||
"is_snowflake_available",
|
||||
"is_snowflake_configured",
|
||||
"SNOWFLAKE_AVAILABLE",
|
||||
# Query result caching
|
||||
"QueryCache",
|
||||
"CacheEntry",
|
||||
"CacheStats",
|
||||
"get_cache",
|
||||
"reset_cache",
|
||||
"is_cache_enabled",
|
||||
# Data source management with fallback chain
|
||||
"DataSourceType",
|
||||
"DataSourceResult",
|
||||
"SourceStatus",
|
||||
"DataSourceManager",
|
||||
"get_data_source_manager",
|
||||
"get_data",
|
||||
"reset_data_source_manager",
|
||||
# Diagnosis lookup
|
||||
"ClusterSnomedCodes",
|
||||
"IndicationValidationResult",
|
||||
"DrugIndicationMatchRate",
|
||||
"get_drug_clusters",
|
||||
"get_drug_cluster_ids",
|
||||
"get_cluster_snomed_codes",
|
||||
"patient_has_indication",
|
||||
"validate_indication",
|
||||
"get_indication_match_rate",
|
||||
"batch_validate_indications",
|
||||
"get_available_clusters",
|
||||
]
|
||||
@@ -0,0 +1,553 @@
|
||||
"""
|
||||
Query result caching module for NHS Patient Pathway Analysis.
|
||||
|
||||
Provides file-based caching for Snowflake query results with TTL-based invalidation.
|
||||
Supports different TTLs for historical data vs data including the current date.
|
||||
|
||||
Cache keys are generated from query hashes. Results are stored as compressed JSON.
|
||||
|
||||
Usage:
|
||||
from data_processing.cache import QueryCache, get_cache
|
||||
|
||||
cache = get_cache()
|
||||
|
||||
# Check for cached result
|
||||
result = cache.get(query, params)
|
||||
if result is None:
|
||||
# Execute query and cache result
|
||||
result = execute_query(query, params)
|
||||
cache.set(query, params, result, includes_current_data=False)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, date
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
import gzip
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
from config import get_snowflake_config, CacheConfig
|
||||
from core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntry:
|
||||
"""Metadata for a cached query result."""
|
||||
cache_key: str
|
||||
query_hash: str
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
includes_current_data: bool
|
||||
row_count: int
|
||||
file_size_bytes: int
|
||||
file_path: Path
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats:
|
||||
"""Statistics about the cache."""
|
||||
enabled: bool
|
||||
cache_dir: Path
|
||||
total_entries: int
|
||||
total_size_mb: float
|
||||
max_size_mb: int
|
||||
oldest_entry: Optional[datetime]
|
||||
newest_entry: Optional[datetime]
|
||||
hit_count: int
|
||||
miss_count: int
|
||||
|
||||
|
||||
class QueryCache:
|
||||
"""
|
||||
File-based cache for Snowflake query results.
|
||||
|
||||
Results are stored as gzipped JSON files with TTL-based expiration.
|
||||
Supports different TTLs for historical vs current data.
|
||||
|
||||
Attributes:
|
||||
config: CacheConfig with cache settings
|
||||
cache_dir: Path to cache directory
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[CacheConfig] = None, base_path: Optional[Path] = None):
|
||||
"""
|
||||
Initialize the query cache.
|
||||
|
||||
Args:
|
||||
config: Optional CacheConfig. If not provided, loads from snowflake.toml
|
||||
base_path: Base path for relative cache directory. Defaults to cwd.
|
||||
"""
|
||||
if config is None:
|
||||
sf_config = get_snowflake_config()
|
||||
config = sf_config.cache
|
||||
|
||||
self._config = config
|
||||
self._base_path = base_path or Path.cwd()
|
||||
|
||||
# Resolve cache directory
|
||||
cache_dir = Path(config.directory)
|
||||
if not cache_dir.is_absolute():
|
||||
cache_dir = self._base_path / cache_dir
|
||||
self._cache_dir = cache_dir
|
||||
|
||||
# Stats tracking (in-memory only, reset on restart)
|
||||
self._hit_count = 0
|
||||
self._miss_count = 0
|
||||
|
||||
# Ensure cache directory exists if enabled
|
||||
if self._config.enabled:
|
||||
self._cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@property
|
||||
def config(self) -> CacheConfig:
|
||||
"""Return the cache configuration."""
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def cache_dir(self) -> Path:
|
||||
"""Return the cache directory path."""
|
||||
return self._cache_dir
|
||||
|
||||
@property
|
||||
def is_enabled(self) -> bool:
|
||||
"""Return True if caching is enabled."""
|
||||
return self._config.enabled
|
||||
|
||||
def _generate_cache_key(self, query: str, params: Optional[tuple] = None) -> str:
|
||||
"""
|
||||
Generate a cache key from query and parameters.
|
||||
|
||||
Uses SHA256 hash of query + params to create unique key.
|
||||
"""
|
||||
# Normalize query (strip whitespace, lowercase)
|
||||
normalized_query = " ".join(query.lower().split())
|
||||
|
||||
# Combine query and params
|
||||
key_content = normalized_query
|
||||
if params:
|
||||
key_content += "|" + "|".join(str(p) for p in params)
|
||||
|
||||
# Hash to create key
|
||||
hash_obj = hashlib.sha256(key_content.encode("utf-8"))
|
||||
return hash_obj.hexdigest()[:32] # Use first 32 chars for readability
|
||||
|
||||
def _get_cache_file_path(self, cache_key: str) -> Path:
|
||||
"""Get the file path for a cache entry."""
|
||||
return self._cache_dir / f"{cache_key}.json.gz"
|
||||
|
||||
def _get_meta_file_path(self, cache_key: str) -> Path:
|
||||
"""Get the metadata file path for a cache entry."""
|
||||
return self._cache_dir / f"{cache_key}.meta.json"
|
||||
|
||||
def _is_expired(self, meta: dict) -> bool:
|
||||
"""Check if a cache entry is expired based on its metadata."""
|
||||
expires_at = datetime.fromisoformat(meta["expires_at"])
|
||||
return datetime.now() > expires_at
|
||||
|
||||
def get(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[tuple] = None,
|
||||
check_expiry: bool = True
|
||||
) -> Optional[list[dict]]:
|
||||
"""
|
||||
Get a cached query result.
|
||||
|
||||
Args:
|
||||
query: SQL query string
|
||||
params: Optional query parameters
|
||||
check_expiry: If True, returns None for expired entries
|
||||
|
||||
Returns:
|
||||
Cached result as list of dicts, or None if not cached/expired
|
||||
"""
|
||||
if not self.is_enabled:
|
||||
self._miss_count += 1
|
||||
return None
|
||||
|
||||
cache_key = self._generate_cache_key(query, params)
|
||||
cache_file = self._get_cache_file_path(cache_key)
|
||||
meta_file = self._get_meta_file_path(cache_key)
|
||||
|
||||
# Check if files exist
|
||||
if not cache_file.exists() or not meta_file.exists():
|
||||
self._miss_count += 1
|
||||
logger.debug(f"Cache miss (not found): {cache_key}")
|
||||
return None
|
||||
|
||||
# Load and check metadata
|
||||
try:
|
||||
with open(meta_file, "r", encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
if check_expiry and self._is_expired(meta):
|
||||
self._miss_count += 1
|
||||
logger.debug(f"Cache miss (expired): {cache_key}")
|
||||
return None
|
||||
|
||||
# Load cached data
|
||||
with gzip.open(cache_file, "rt", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
self._hit_count += 1
|
||||
logger.info(f"Cache hit: {cache_key} ({meta['row_count']} rows)")
|
||||
return data
|
||||
|
||||
except (json.JSONDecodeError, KeyError, OSError) as e:
|
||||
logger.warning(f"Cache read error for {cache_key}: {e}")
|
||||
self._miss_count += 1
|
||||
# Clean up corrupted entry
|
||||
self._delete_entry(cache_key)
|
||||
return None
|
||||
|
||||
def set(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[tuple],
|
||||
data: list[dict],
|
||||
includes_current_data: bool = False,
|
||||
custom_ttl_seconds: Optional[int] = None
|
||||
) -> Optional[CacheEntry]:
|
||||
"""
|
||||
Cache a query result.
|
||||
|
||||
Args:
|
||||
query: SQL query string
|
||||
params: Optional query parameters
|
||||
data: Query result as list of dicts
|
||||
includes_current_data: If True, uses shorter TTL for current data
|
||||
custom_ttl_seconds: Optional custom TTL (overrides config)
|
||||
|
||||
Returns:
|
||||
CacheEntry with metadata, or None if caching disabled/failed
|
||||
"""
|
||||
if not self.is_enabled:
|
||||
return None
|
||||
|
||||
cache_key = self._generate_cache_key(query, params)
|
||||
cache_file = self._get_cache_file_path(cache_key)
|
||||
meta_file = self._get_meta_file_path(cache_key)
|
||||
|
||||
# Determine TTL
|
||||
if custom_ttl_seconds is not None:
|
||||
ttl = custom_ttl_seconds
|
||||
elif includes_current_data:
|
||||
ttl = self._config.ttl_current_data_seconds
|
||||
else:
|
||||
ttl = self._config.ttl_seconds
|
||||
|
||||
now = datetime.now()
|
||||
expires_at = datetime.fromtimestamp(now.timestamp() + ttl)
|
||||
|
||||
try:
|
||||
# Write compressed data
|
||||
with gzip.open(cache_file, "wt", encoding="utf-8", compresslevel=6) as f:
|
||||
json.dump(data, f, default=str)
|
||||
|
||||
file_size = cache_file.stat().st_size
|
||||
|
||||
# Write metadata
|
||||
meta = {
|
||||
"cache_key": cache_key,
|
||||
"query_hash": hashlib.sha256(query.encode()).hexdigest()[:16],
|
||||
"created_at": now.isoformat(),
|
||||
"expires_at": expires_at.isoformat(),
|
||||
"includes_current_data": includes_current_data,
|
||||
"row_count": len(data),
|
||||
"file_size_bytes": file_size,
|
||||
"ttl_seconds": ttl,
|
||||
}
|
||||
|
||||
with open(meta_file, "w", encoding="utf-8") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
|
||||
logger.info(f"Cached {len(data)} rows as {cache_key} (expires in {ttl}s)")
|
||||
|
||||
# Check if we need to enforce size limit
|
||||
self._enforce_size_limit()
|
||||
|
||||
return CacheEntry(
|
||||
cache_key=cache_key,
|
||||
query_hash=str(meta["query_hash"]),
|
||||
created_at=now,
|
||||
expires_at=expires_at,
|
||||
includes_current_data=includes_current_data,
|
||||
row_count=len(data),
|
||||
file_size_bytes=file_size,
|
||||
file_path=cache_file,
|
||||
)
|
||||
|
||||
except (OSError, TypeError) as e:
|
||||
logger.error(f"Failed to cache result: {e}")
|
||||
return None
|
||||
|
||||
def invalidate(self, query: str, params: Optional[tuple] = None) -> bool:
|
||||
"""
|
||||
Invalidate a specific cache entry.
|
||||
|
||||
Args:
|
||||
query: SQL query string
|
||||
params: Optional query parameters
|
||||
|
||||
Returns:
|
||||
True if entry was deleted, False if not found
|
||||
"""
|
||||
cache_key = self._generate_cache_key(query, params)
|
||||
return self._delete_entry(cache_key)
|
||||
|
||||
def _delete_entry(self, cache_key: str) -> bool:
|
||||
"""Delete a cache entry by key."""
|
||||
cache_file = self._get_cache_file_path(cache_key)
|
||||
meta_file = self._get_meta_file_path(cache_key)
|
||||
|
||||
deleted = False
|
||||
|
||||
if cache_file.exists():
|
||||
cache_file.unlink()
|
||||
deleted = True
|
||||
|
||||
if meta_file.exists():
|
||||
meta_file.unlink()
|
||||
deleted = True
|
||||
|
||||
if deleted:
|
||||
logger.debug(f"Deleted cache entry: {cache_key}")
|
||||
|
||||
return deleted
|
||||
|
||||
def clear(self) -> int:
|
||||
"""
|
||||
Clear all cache entries.
|
||||
|
||||
Returns:
|
||||
Number of entries deleted
|
||||
"""
|
||||
if not self._cache_dir.exists():
|
||||
return 0
|
||||
|
||||
count = 0
|
||||
for file in self._cache_dir.glob("*.json*"):
|
||||
try:
|
||||
file.unlink()
|
||||
count += 1
|
||||
except OSError as e:
|
||||
logger.warning(f"Failed to delete {file}: {e}")
|
||||
|
||||
# Reset stats
|
||||
self._hit_count = 0
|
||||
self._miss_count = 0
|
||||
|
||||
logger.info(f"Cleared {count} cache files")
|
||||
return count // 2 # Divide by 2 since we have .json.gz and .meta.json
|
||||
|
||||
def clear_expired(self) -> int:
|
||||
"""
|
||||
Remove expired cache entries.
|
||||
|
||||
Returns:
|
||||
Number of expired entries deleted
|
||||
"""
|
||||
if not self._cache_dir.exists():
|
||||
return 0
|
||||
|
||||
count = 0
|
||||
for meta_file in self._cache_dir.glob("*.meta.json"):
|
||||
try:
|
||||
with open(meta_file, "r", encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
if self._is_expired(meta):
|
||||
cache_key = meta_file.stem.replace(".meta", "")
|
||||
self._delete_entry(cache_key)
|
||||
count += 1
|
||||
except (OSError, json.JSONDecodeError):
|
||||
# Delete corrupted metadata files
|
||||
cache_key = meta_file.stem.replace(".meta", "")
|
||||
self._delete_entry(cache_key)
|
||||
count += 1
|
||||
|
||||
logger.info(f"Cleared {count} expired cache entries")
|
||||
return count
|
||||
|
||||
def _get_total_size_mb(self) -> float:
|
||||
"""Calculate total cache size in MB."""
|
||||
if not self._cache_dir.exists():
|
||||
return 0.0
|
||||
|
||||
total_bytes = sum(
|
||||
f.stat().st_size
|
||||
for f in self._cache_dir.glob("*")
|
||||
if f.is_file()
|
||||
)
|
||||
return total_bytes / (1024 * 1024)
|
||||
|
||||
def _enforce_size_limit(self) -> int:
|
||||
"""
|
||||
Enforce cache size limit by removing oldest entries.
|
||||
|
||||
Returns:
|
||||
Number of entries removed
|
||||
"""
|
||||
max_size_mb = self._config.max_size_mb
|
||||
current_size_mb = self._get_total_size_mb()
|
||||
|
||||
if current_size_mb <= max_size_mb:
|
||||
return 0
|
||||
|
||||
# Get all entries sorted by creation time
|
||||
entries = []
|
||||
for meta_file in self._cache_dir.glob("*.meta.json"):
|
||||
try:
|
||||
with open(meta_file, "r", encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
entries.append((
|
||||
meta_file.stem.replace(".meta", ""),
|
||||
datetime.fromisoformat(meta["created_at"]),
|
||||
meta.get("file_size_bytes", 0)
|
||||
))
|
||||
except (OSError, json.JSONDecodeError, KeyError):
|
||||
# Clean up corrupted entry
|
||||
cache_key = meta_file.stem.replace(".meta", "")
|
||||
self._delete_entry(cache_key)
|
||||
|
||||
# Sort by creation time (oldest first)
|
||||
entries.sort(key=lambda x: x[1])
|
||||
|
||||
# Remove oldest entries until under limit
|
||||
removed = 0
|
||||
size_to_remove_bytes = (current_size_mb - max_size_mb * 0.9) * 1024 * 1024 # Target 90% of limit
|
||||
removed_bytes = 0
|
||||
|
||||
for cache_key, created_at, file_size in entries:
|
||||
if removed_bytes >= size_to_remove_bytes:
|
||||
break
|
||||
|
||||
self._delete_entry(cache_key)
|
||||
removed_bytes += file_size
|
||||
removed += 1
|
||||
|
||||
logger.info(f"Removed {removed} cache entries to enforce size limit")
|
||||
return removed
|
||||
|
||||
def get_stats(self) -> CacheStats:
|
||||
"""Get cache statistics."""
|
||||
if not self._cache_dir.exists():
|
||||
return CacheStats(
|
||||
enabled=self.is_enabled,
|
||||
cache_dir=self._cache_dir,
|
||||
total_entries=0,
|
||||
total_size_mb=0.0,
|
||||
max_size_mb=self._config.max_size_mb,
|
||||
oldest_entry=None,
|
||||
newest_entry=None,
|
||||
hit_count=self._hit_count,
|
||||
miss_count=self._miss_count,
|
||||
)
|
||||
|
||||
entries = []
|
||||
for meta_file in self._cache_dir.glob("*.meta.json"):
|
||||
try:
|
||||
with open(meta_file, "r", encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
entries.append(datetime.fromisoformat(meta["created_at"]))
|
||||
except (OSError, json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
|
||||
oldest = min(entries) if entries else None
|
||||
newest = max(entries) if entries else None
|
||||
|
||||
return CacheStats(
|
||||
enabled=self.is_enabled,
|
||||
cache_dir=self._cache_dir,
|
||||
total_entries=len(entries),
|
||||
total_size_mb=self._get_total_size_mb(),
|
||||
max_size_mb=self._config.max_size_mb,
|
||||
oldest_entry=oldest,
|
||||
newest_entry=newest,
|
||||
hit_count=self._hit_count,
|
||||
miss_count=self._miss_count,
|
||||
)
|
||||
|
||||
def list_entries(self) -> list[CacheEntry]:
|
||||
"""List all cache entries with metadata."""
|
||||
if not self._cache_dir.exists():
|
||||
return []
|
||||
|
||||
entries = []
|
||||
for meta_file in self._cache_dir.glob("*.meta.json"):
|
||||
try:
|
||||
with open(meta_file, "r", encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
cache_key = meta["cache_key"]
|
||||
entries.append(CacheEntry(
|
||||
cache_key=cache_key,
|
||||
query_hash=meta.get("query_hash", ""),
|
||||
created_at=datetime.fromisoformat(meta["created_at"]),
|
||||
expires_at=datetime.fromisoformat(meta["expires_at"]),
|
||||
includes_current_data=meta.get("includes_current_data", False),
|
||||
row_count=meta.get("row_count", 0),
|
||||
file_size_bytes=meta.get("file_size_bytes", 0),
|
||||
file_path=self._get_cache_file_path(cache_key),
|
||||
))
|
||||
except (OSError, json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
|
||||
# Sort by creation time (newest first)
|
||||
entries.sort(key=lambda x: x.created_at, reverse=True)
|
||||
return entries
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
_default_cache: Optional[QueryCache] = None
|
||||
|
||||
|
||||
def get_cache(config: Optional[CacheConfig] = None) -> QueryCache:
|
||||
"""
|
||||
Get a QueryCache instance (creates singleton on first call).
|
||||
|
||||
Args:
|
||||
config: Optional CacheConfig. If provided, creates new cache with
|
||||
this config. If None, uses/creates default cache.
|
||||
|
||||
Returns:
|
||||
QueryCache instance
|
||||
"""
|
||||
global _default_cache
|
||||
|
||||
if config is not None:
|
||||
# Custom config requested, create new cache
|
||||
return QueryCache(config)
|
||||
|
||||
if _default_cache is None:
|
||||
_default_cache = QueryCache()
|
||||
|
||||
return _default_cache
|
||||
|
||||
|
||||
def reset_cache() -> None:
|
||||
"""Reset the default cache singleton."""
|
||||
global _default_cache
|
||||
_default_cache = None
|
||||
|
||||
|
||||
def is_cache_enabled() -> bool:
|
||||
"""Return True if caching is enabled in configuration."""
|
||||
config = get_snowflake_config()
|
||||
return config.cache.enabled
|
||||
|
||||
|
||||
# Export public API
|
||||
__all__ = [
|
||||
"QueryCache",
|
||||
"CacheEntry",
|
||||
"CacheStats",
|
||||
"get_cache",
|
||||
"reset_cache",
|
||||
"is_cache_enabled",
|
||||
]
|
||||
@@ -0,0 +1,968 @@
|
||||
"""
|
||||
Unified data access layer with fallback chain for NHS Patient Pathway Analysis.
|
||||
|
||||
Provides a high-level interface that automatically selects the best available data source:
|
||||
1. Cache - Returns cached results if valid and not expired
|
||||
2. Snowflake - Queries Snowflake warehouse if configured and connected
|
||||
3. Local - Falls back to SQLite database or CSV/Parquet files
|
||||
|
||||
The fallback chain handles connection errors, missing configurations, and
|
||||
unavailable services gracefully, always attempting to provide data from
|
||||
some source.
|
||||
|
||||
Usage:
|
||||
from data_processing.data_source import DataSourceManager, get_data
|
||||
|
||||
# Simple usage with automatic source selection
|
||||
result = get_data(
|
||||
start_date=date(2024, 1, 1),
|
||||
end_date=date(2024, 12, 31),
|
||||
trusts=["TRUST A", "TRUST B"],
|
||||
)
|
||||
|
||||
# Or with explicit source preference
|
||||
manager = DataSourceManager()
|
||||
result = manager.get_data(
|
||||
start_date=date(2024, 1, 1),
|
||||
end_date=date(2024, 12, 31),
|
||||
preferred_source="snowflake",
|
||||
)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import date, datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional, Callable
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DataSourceType(Enum):
|
||||
"""Enumeration of available data sources."""
|
||||
CACHE = "cache"
|
||||
SNOWFLAKE = "snowflake"
|
||||
SQLITE = "sqlite"
|
||||
FILE = "file"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataSourceResult:
|
||||
"""Result from data source query.
|
||||
|
||||
Attributes:
|
||||
df: The loaded DataFrame with patient intervention data
|
||||
source_type: Which data source was used
|
||||
source_detail: Additional details about the source (e.g., file path, query hash)
|
||||
row_count: Number of rows returned
|
||||
cached: Whether the result came from cache
|
||||
from_fallback: Whether a fallback source was used
|
||||
load_time_seconds: Time taken to load data
|
||||
warnings: Any warnings generated during loading
|
||||
"""
|
||||
df: pd.DataFrame
|
||||
source_type: DataSourceType
|
||||
source_detail: str = ""
|
||||
row_count: int = 0
|
||||
cached: bool = False
|
||||
from_fallback: bool = False
|
||||
load_time_seconds: float = 0.0
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.row_count == 0 and self.df is not None:
|
||||
self.row_count = len(self.df)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SourceStatus:
|
||||
"""Status of a data source.
|
||||
|
||||
Attributes:
|
||||
source_type: The type of data source
|
||||
available: Whether the source is available
|
||||
configured: Whether the source is properly configured
|
||||
message: Status message explaining the state
|
||||
last_checked: When the status was last checked
|
||||
"""
|
||||
source_type: DataSourceType
|
||||
available: bool = False
|
||||
configured: bool = False
|
||||
message: str = ""
|
||||
last_checked: Optional[datetime] = None
|
||||
|
||||
|
||||
class DataSourceManager:
|
||||
"""
|
||||
Manages data access with automatic fallback between sources.
|
||||
|
||||
The manager attempts to retrieve data from sources in order of preference:
|
||||
1. Cache (if enabled and has valid cached data)
|
||||
2. Snowflake (if configured and connected)
|
||||
3. SQLite (if database exists with data)
|
||||
4. Local files (CSV/Parquet)
|
||||
|
||||
Attributes:
|
||||
cache_enabled: Whether to use caching
|
||||
local_file_path: Path to local CSV/Parquet file (optional fallback)
|
||||
sqlite_db_path: Path to SQLite database (optional)
|
||||
|
||||
Example:
|
||||
manager = DataSourceManager()
|
||||
|
||||
# Check what sources are available
|
||||
status = manager.check_all_sources()
|
||||
for s in status:
|
||||
print(f"{s.source_type.value}: {s.message}")
|
||||
|
||||
# Get data with automatic fallback
|
||||
result = manager.get_data(
|
||||
start_date=date(2024, 1, 1),
|
||||
end_date=date(2024, 6, 30),
|
||||
)
|
||||
print(f"Got {result.row_count} rows from {result.source_type.value}")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_enabled: bool = True,
|
||||
local_file_path: Optional[Path | str] = None,
|
||||
sqlite_db_path: Optional[Path | str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the data source manager.
|
||||
|
||||
Args:
|
||||
cache_enabled: Whether to check cache before querying (default True)
|
||||
local_file_path: Path to local CSV/Parquet file for file fallback
|
||||
sqlite_db_path: Path to SQLite database (uses default if None)
|
||||
"""
|
||||
self._cache_enabled = cache_enabled
|
||||
self._local_file_path = Path(local_file_path) if local_file_path else None
|
||||
self._sqlite_db_path = Path(sqlite_db_path) if sqlite_db_path else None
|
||||
self._source_status: dict[DataSourceType, SourceStatus] = {}
|
||||
|
||||
@property
|
||||
def cache_enabled(self) -> bool:
|
||||
"""Return whether caching is enabled."""
|
||||
return self._cache_enabled
|
||||
|
||||
@cache_enabled.setter
|
||||
def cache_enabled(self, value: bool):
|
||||
"""Set whether caching is enabled."""
|
||||
self._cache_enabled = value
|
||||
|
||||
def _check_cache_status(self) -> SourceStatus:
|
||||
"""Check if cache is available."""
|
||||
try:
|
||||
from data_processing.cache import is_cache_enabled, get_cache
|
||||
|
||||
if not is_cache_enabled():
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.CACHE,
|
||||
available=False,
|
||||
configured=False,
|
||||
message="Cache disabled in configuration",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
cache = get_cache()
|
||||
stats = cache.get_stats()
|
||||
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.CACHE,
|
||||
available=True,
|
||||
configured=True,
|
||||
message=f"Cache enabled ({stats.total_entries} entries, {stats.total_size_mb:.1f}MB)",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
except Exception as e:
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.CACHE,
|
||||
available=False,
|
||||
configured=False,
|
||||
message=f"Cache error: {e}",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
def _check_snowflake_status(self) -> SourceStatus:
|
||||
"""Check if Snowflake is available and configured."""
|
||||
try:
|
||||
from data_processing.snowflake_connector import (
|
||||
is_snowflake_available,
|
||||
is_snowflake_configured,
|
||||
)
|
||||
|
||||
if not is_snowflake_available():
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.SNOWFLAKE,
|
||||
available=False,
|
||||
configured=False,
|
||||
message="snowflake-connector-python not installed",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
if not is_snowflake_configured():
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.SNOWFLAKE,
|
||||
available=True,
|
||||
configured=False,
|
||||
message="Snowflake account not configured in config/snowflake.toml",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.SNOWFLAKE,
|
||||
available=True,
|
||||
configured=True,
|
||||
message="Snowflake configured and ready",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
except Exception as e:
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.SNOWFLAKE,
|
||||
available=False,
|
||||
configured=False,
|
||||
message=f"Snowflake error: {e}",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
def _check_sqlite_status(self) -> SourceStatus:
|
||||
"""Check if SQLite database is available with data."""
|
||||
try:
|
||||
from data_processing.database import default_db_manager, default_db_config
|
||||
|
||||
db_path = self._sqlite_db_path or Path(default_db_config.db_path)
|
||||
|
||||
if not db_path.exists():
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.SQLITE,
|
||||
available=False,
|
||||
configured=True,
|
||||
message=f"Database not found: {db_path}",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
from data_processing.database import DatabaseManager, DatabaseConfig
|
||||
|
||||
config = DatabaseConfig(db_path=db_path)
|
||||
manager = DatabaseManager(config)
|
||||
|
||||
if not manager.table_exists("fact_interventions"):
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.SQLITE,
|
||||
available=False,
|
||||
configured=True,
|
||||
message="fact_interventions table not found",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
count = manager.get_table_count("fact_interventions")
|
||||
if count == 0:
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.SQLITE,
|
||||
available=False,
|
||||
configured=True,
|
||||
message="fact_interventions table is empty",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.SQLITE,
|
||||
available=True,
|
||||
configured=True,
|
||||
message=f"SQLite database ready ({count:,} rows)",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
except Exception as e:
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.SQLITE,
|
||||
available=False,
|
||||
configured=False,
|
||||
message=f"SQLite error: {e}",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
def _check_file_status(self) -> SourceStatus:
|
||||
"""Check if local file is available."""
|
||||
if self._local_file_path is None:
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.FILE,
|
||||
available=False,
|
||||
configured=False,
|
||||
message="No local file path configured",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
if not self._local_file_path.exists():
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.FILE,
|
||||
available=False,
|
||||
configured=True,
|
||||
message=f"File not found: {self._local_file_path}",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
size_mb = self._local_file_path.stat().st_size / (1024 * 1024)
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.FILE,
|
||||
available=True,
|
||||
configured=True,
|
||||
message=f"Local file ready: {self._local_file_path.name} ({size_mb:.1f}MB)",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
def check_source_status(self, source_type: DataSourceType) -> SourceStatus:
|
||||
"""
|
||||
Check the status of a specific data source.
|
||||
|
||||
Args:
|
||||
source_type: The type of source to check
|
||||
|
||||
Returns:
|
||||
SourceStatus with current availability information
|
||||
"""
|
||||
if source_type == DataSourceType.CACHE:
|
||||
return self._check_cache_status()
|
||||
elif source_type == DataSourceType.SNOWFLAKE:
|
||||
return self._check_snowflake_status()
|
||||
elif source_type == DataSourceType.SQLITE:
|
||||
return self._check_sqlite_status()
|
||||
elif source_type == DataSourceType.FILE:
|
||||
return self._check_file_status()
|
||||
else:
|
||||
return SourceStatus(
|
||||
source_type=source_type,
|
||||
available=False,
|
||||
configured=False,
|
||||
message=f"Unknown source type: {source_type}",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
def check_all_sources(self) -> list[SourceStatus]:
|
||||
"""
|
||||
Check the status of all data sources.
|
||||
|
||||
Returns:
|
||||
List of SourceStatus for each source type
|
||||
"""
|
||||
statuses = []
|
||||
for source_type in DataSourceType:
|
||||
status = self.check_source_status(source_type)
|
||||
self._source_status[source_type] = status
|
||||
statuses.append(status)
|
||||
return statuses
|
||||
|
||||
def _build_cache_key_params(
|
||||
self,
|
||||
start_date: Optional[date],
|
||||
end_date: Optional[date],
|
||||
trusts: Optional[list[str]],
|
||||
drugs: Optional[list[str]],
|
||||
directories: Optional[list[str]],
|
||||
) -> tuple[str, tuple]:
|
||||
"""Build a cache-compatible query string and params for the filter criteria."""
|
||||
# Create a canonical representation for caching
|
||||
query_parts = ["SELECT * FROM activity_data"]
|
||||
params = []
|
||||
|
||||
conditions = []
|
||||
if start_date:
|
||||
conditions.append("start_date >= ?")
|
||||
params.append(str(start_date))
|
||||
if end_date:
|
||||
conditions.append("end_date <= ?")
|
||||
params.append(str(end_date))
|
||||
if trusts:
|
||||
placeholders = ",".join(["?"] * len(trusts))
|
||||
conditions.append(f"trust IN ({placeholders})")
|
||||
params.extend(sorted(trusts))
|
||||
if drugs:
|
||||
placeholders = ",".join(["?"] * len(drugs))
|
||||
conditions.append(f"drug IN ({placeholders})")
|
||||
params.extend(sorted(drugs))
|
||||
if directories:
|
||||
placeholders = ",".join(["?"] * len(directories))
|
||||
conditions.append(f"directory IN ({placeholders})")
|
||||
params.extend(sorted(directories))
|
||||
|
||||
if conditions:
|
||||
query_parts.append("WHERE " + " AND ".join(conditions))
|
||||
|
||||
query = " ".join(query_parts)
|
||||
return query, tuple(params)
|
||||
|
||||
def _try_cache(
|
||||
self,
|
||||
start_date: Optional[date],
|
||||
end_date: Optional[date],
|
||||
trusts: Optional[list[str]],
|
||||
drugs: Optional[list[str]],
|
||||
directories: Optional[list[str]],
|
||||
) -> Optional[DataSourceResult]:
|
||||
"""Try to get data from cache."""
|
||||
if not self._cache_enabled:
|
||||
return None
|
||||
|
||||
try:
|
||||
from data_processing.cache import get_cache
|
||||
|
||||
cache = get_cache()
|
||||
if not cache.is_enabled:
|
||||
return None
|
||||
|
||||
query, params = self._build_cache_key_params(
|
||||
start_date, end_date, trusts, drugs, directories
|
||||
)
|
||||
|
||||
cached_data = cache.get(query, params)
|
||||
if cached_data is None:
|
||||
logger.debug("Cache miss")
|
||||
return None
|
||||
|
||||
# Convert cached data back to DataFrame
|
||||
df = pd.DataFrame(cached_data)
|
||||
|
||||
# Convert date columns
|
||||
if 'Intervention Date' in df.columns:
|
||||
df['Intervention Date'] = pd.to_datetime(df['Intervention Date'])
|
||||
|
||||
logger.info(f"Cache hit: {len(df)} rows")
|
||||
|
||||
return DataSourceResult(
|
||||
df=df,
|
||||
source_type=DataSourceType.CACHE,
|
||||
source_detail=f"cache_key={query[:50]}...",
|
||||
row_count=len(df),
|
||||
cached=True,
|
||||
from_fallback=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache lookup failed: {e}")
|
||||
return None
|
||||
|
||||
def _try_snowflake(
|
||||
self,
|
||||
start_date: Optional[date],
|
||||
end_date: Optional[date],
|
||||
trusts: Optional[list[str]],
|
||||
drugs: Optional[list[str]],
|
||||
directories: Optional[list[str]],
|
||||
progress_callback: Optional[Callable[[int, int], None]] = None,
|
||||
) -> Optional[DataSourceResult]:
|
||||
"""Try to get data from Snowflake."""
|
||||
import time
|
||||
|
||||
try:
|
||||
from data_processing.snowflake_connector import (
|
||||
is_snowflake_available,
|
||||
is_snowflake_configured,
|
||||
get_connector,
|
||||
SnowflakeConnectionError,
|
||||
)
|
||||
|
||||
if not is_snowflake_available():
|
||||
logger.debug("Snowflake connector not installed")
|
||||
return None
|
||||
|
||||
if not is_snowflake_configured():
|
||||
logger.debug("Snowflake not configured")
|
||||
return None
|
||||
|
||||
# Get connector and fetch data
|
||||
connector = get_connector()
|
||||
logger.info("Fetching data from Snowflake...")
|
||||
start_time = time.time()
|
||||
|
||||
# Fetch activity data from Snowflake
|
||||
# Note: provider_codes filter not directly supported yet - would need trust name to code mapping
|
||||
rows = connector.fetch_activity_data(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
provider_codes=None, # TODO: map trust names to provider codes if needed
|
||||
)
|
||||
|
||||
if not rows:
|
||||
logger.warning("Snowflake returned no data")
|
||||
return None
|
||||
|
||||
# Convert to DataFrame
|
||||
df = pd.DataFrame(rows)
|
||||
load_time = time.time() - start_time
|
||||
|
||||
logger.info(f"Snowflake loaded {len(df)} rows in {load_time:.2f}s")
|
||||
|
||||
# Apply local transformations to match expected format
|
||||
# (patient_id, drug_names, department_identification)
|
||||
from tools.data import patient_id, drug_names, department_identification
|
||||
from core import default_paths
|
||||
|
||||
df = patient_id(df)
|
||||
df = drug_names(df, paths=default_paths)
|
||||
df = department_identification(df, paths=default_paths)
|
||||
|
||||
# Apply additional filters if provided
|
||||
if trusts and 'OrganisationName' in df.columns:
|
||||
df = df[df['OrganisationName'].isin(trusts)]
|
||||
if drugs and 'Drug Name' in df.columns:
|
||||
df = df[df['Drug Name'].isin(drugs)]
|
||||
if directories and 'Directory' in df.columns:
|
||||
df = df[df['Directory'].isin(directories)]
|
||||
|
||||
return DataSourceResult(
|
||||
df=df,
|
||||
source_type=DataSourceType.SNOWFLAKE,
|
||||
source_detail="DATA_HUB.CDM.Acute__Conmon__PatientLevelDrugs",
|
||||
row_count=len(df),
|
||||
cached=False,
|
||||
from_fallback=False,
|
||||
load_time_seconds=load_time,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Snowflake query failed: {e}")
|
||||
return None
|
||||
|
||||
def _try_sqlite(
|
||||
self,
|
||||
start_date: Optional[date],
|
||||
end_date: Optional[date],
|
||||
trusts: Optional[list[str]],
|
||||
drugs: Optional[list[str]],
|
||||
directories: Optional[list[str]],
|
||||
) -> Optional[DataSourceResult]:
|
||||
"""Try to get data from SQLite."""
|
||||
import time
|
||||
|
||||
try:
|
||||
from data_processing.loader import SQLiteDataLoader
|
||||
|
||||
# Determine database path
|
||||
db_path = self._sqlite_db_path
|
||||
if db_path is None:
|
||||
from data_processing.database import default_db_config
|
||||
db_path = Path(default_db_config.db_path)
|
||||
|
||||
loader = SQLiteDataLoader(
|
||||
db_path=db_path,
|
||||
date_range=(start_date, end_date) if start_date and end_date else None,
|
||||
trusts=trusts,
|
||||
drugs=drugs,
|
||||
directories=directories,
|
||||
)
|
||||
|
||||
# Check if source is valid
|
||||
is_valid, msg = loader.validate_source()
|
||||
if not is_valid:
|
||||
logger.debug(f"SQLite not available: {msg}")
|
||||
return None
|
||||
|
||||
start_time = time.time()
|
||||
result = loader.load()
|
||||
load_time = time.time() - start_time
|
||||
|
||||
logger.info(f"SQLite loaded {result.row_count} rows in {load_time:.2f}s")
|
||||
|
||||
return DataSourceResult(
|
||||
df=result.df,
|
||||
source_type=DataSourceType.SQLITE,
|
||||
source_detail=str(db_path),
|
||||
row_count=result.row_count,
|
||||
cached=False,
|
||||
from_fallback=False,
|
||||
load_time_seconds=load_time,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"SQLite query failed: {e}")
|
||||
return None
|
||||
|
||||
def _try_file(
|
||||
self,
|
||||
start_date: Optional[date],
|
||||
end_date: Optional[date],
|
||||
trusts: Optional[list[str]],
|
||||
drugs: Optional[list[str]],
|
||||
directories: Optional[list[str]],
|
||||
) -> Optional[DataSourceResult]:
|
||||
"""Try to get data from local file."""
|
||||
import time
|
||||
|
||||
if self._local_file_path is None:
|
||||
logger.debug("No local file configured")
|
||||
return None
|
||||
|
||||
try:
|
||||
from data_processing.loader import FileDataLoader
|
||||
|
||||
loader = FileDataLoader(file_path=self._local_file_path)
|
||||
|
||||
is_valid, msg = loader.validate_source()
|
||||
if not is_valid:
|
||||
logger.debug(f"Local file not available: {msg}")
|
||||
return None
|
||||
|
||||
start_time = time.time()
|
||||
result = loader.load()
|
||||
df = result.df
|
||||
|
||||
# Apply filters (file loader loads all data, then we filter)
|
||||
if start_date and 'Intervention Date' in df.columns:
|
||||
df = df[df['Intervention Date'] >= pd.Timestamp(start_date)]
|
||||
if end_date and 'Intervention Date' in df.columns:
|
||||
df = df[df['Intervention Date'] < pd.Timestamp(end_date)]
|
||||
if trusts and 'OrganisationName' in df.columns:
|
||||
df = df[df['OrganisationName'].isin(trusts)]
|
||||
if drugs and 'Drug Name' in df.columns:
|
||||
df = df[df['Drug Name'].isin(drugs)]
|
||||
if directories and 'Directory' in df.columns:
|
||||
df = df[df['Directory'].isin(directories)]
|
||||
|
||||
load_time = time.time() - start_time
|
||||
|
||||
logger.info(f"File loaded and filtered: {len(df)} rows in {load_time:.2f}s")
|
||||
|
||||
return DataSourceResult(
|
||||
df=df,
|
||||
source_type=DataSourceType.FILE,
|
||||
source_detail=str(self._local_file_path),
|
||||
row_count=len(df),
|
||||
cached=False,
|
||||
from_fallback=True,
|
||||
load_time_seconds=load_time,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"File load failed: {e}")
|
||||
return None
|
||||
|
||||
def get_data(
|
||||
self,
|
||||
start_date: Optional[date] = None,
|
||||
end_date: Optional[date] = None,
|
||||
trusts: Optional[list[str]] = None,
|
||||
drugs: Optional[list[str]] = None,
|
||||
directories: Optional[list[str]] = None,
|
||||
preferred_source: Optional[str] = None,
|
||||
skip_cache: bool = False,
|
||||
progress_callback: Optional[Callable[[int, int], None]] = None,
|
||||
) -> DataSourceResult:
|
||||
"""
|
||||
Get patient intervention data from the best available source.
|
||||
|
||||
The fallback chain is: Cache → Snowflake → SQLite → File
|
||||
|
||||
Args:
|
||||
start_date: Optional start date for filtering (inclusive)
|
||||
end_date: Optional end date for filtering (exclusive)
|
||||
trusts: Optional list of trust names to filter
|
||||
drugs: Optional list of drug names to filter
|
||||
directories: Optional list of directories to filter
|
||||
preferred_source: Optional preferred source ("snowflake", "sqlite", "file")
|
||||
skip_cache: If True, bypass cache and query source directly
|
||||
progress_callback: Optional callback(current, total) for progress updates
|
||||
|
||||
Returns:
|
||||
DataSourceResult with the loaded data and metadata
|
||||
|
||||
Raises:
|
||||
ValueError: If no data source is available or all sources fail
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
warnings = []
|
||||
|
||||
# If preferred source specified, try that first
|
||||
if preferred_source:
|
||||
preferred = preferred_source.lower()
|
||||
if preferred == "snowflake":
|
||||
result = self._try_snowflake(
|
||||
start_date, end_date, trusts, drugs, directories, progress_callback
|
||||
)
|
||||
if result:
|
||||
result.load_time_seconds = time.time() - start_time
|
||||
return result
|
||||
warnings.append("Preferred source 'snowflake' unavailable")
|
||||
|
||||
elif preferred == "sqlite":
|
||||
result = self._try_sqlite(
|
||||
start_date, end_date, trusts, drugs, directories
|
||||
)
|
||||
if result:
|
||||
result.load_time_seconds = time.time() - start_time
|
||||
return result
|
||||
warnings.append("Preferred source 'sqlite' unavailable")
|
||||
|
||||
elif preferred == "file":
|
||||
result = self._try_file(
|
||||
start_date, end_date, trusts, drugs, directories
|
||||
)
|
||||
if result:
|
||||
result.load_time_seconds = time.time() - start_time
|
||||
return result
|
||||
warnings.append("Preferred source 'file' unavailable")
|
||||
|
||||
# Standard fallback chain: cache → snowflake → sqlite → file
|
||||
|
||||
# 1. Try cache first (unless skipped)
|
||||
if not skip_cache:
|
||||
result = self._try_cache(
|
||||
start_date, end_date, trusts, drugs, directories
|
||||
)
|
||||
if result:
|
||||
result.load_time_seconds = time.time() - start_time
|
||||
return result
|
||||
|
||||
# 2. Try Snowflake
|
||||
result = self._try_snowflake(
|
||||
start_date, end_date, trusts, drugs, directories, progress_callback
|
||||
)
|
||||
if result:
|
||||
# Cache the result for future queries
|
||||
if self._cache_enabled:
|
||||
self._cache_result(
|
||||
result.df,
|
||||
start_date, end_date, trusts, drugs, directories,
|
||||
includes_current_data=end_date is None or end_date >= date.today()
|
||||
)
|
||||
result.load_time_seconds = time.time() - start_time
|
||||
return result
|
||||
|
||||
# 3. Try SQLite
|
||||
result = self._try_sqlite(
|
||||
start_date, end_date, trusts, drugs, directories
|
||||
)
|
||||
if result:
|
||||
result.from_fallback = True # Mark as fallback since Snowflake wasn't used
|
||||
result.load_time_seconds = time.time() - start_time
|
||||
if warnings:
|
||||
result.warnings.extend(warnings)
|
||||
return result
|
||||
|
||||
# 4. Try local file
|
||||
result = self._try_file(
|
||||
start_date, end_date, trusts, drugs, directories
|
||||
)
|
||||
if result:
|
||||
result.from_fallback = True
|
||||
result.load_time_seconds = time.time() - start_time
|
||||
if warnings:
|
||||
result.warnings.extend(warnings)
|
||||
return result
|
||||
|
||||
# All sources failed
|
||||
source_status = self.check_all_sources()
|
||||
status_msg = "; ".join(
|
||||
f"{s.source_type.value}: {s.message}" for s in source_status
|
||||
)
|
||||
raise ValueError(f"No data source available. Status: {status_msg}")
|
||||
|
||||
def _cache_result(
|
||||
self,
|
||||
df: pd.DataFrame,
|
||||
start_date: Optional[date],
|
||||
end_date: Optional[date],
|
||||
trusts: Optional[list[str]],
|
||||
drugs: Optional[list[str]],
|
||||
directories: Optional[list[str]],
|
||||
includes_current_data: bool = False,
|
||||
) -> bool:
|
||||
"""Cache a query result for future use."""
|
||||
try:
|
||||
from data_processing.cache import get_cache
|
||||
|
||||
cache = get_cache()
|
||||
if not cache.is_enabled:
|
||||
return False
|
||||
|
||||
query, params = self._build_cache_key_params(
|
||||
start_date, end_date, trusts, drugs, directories
|
||||
)
|
||||
|
||||
# Convert DataFrame to list of dicts for caching
|
||||
# Convert datetime columns to strings for JSON serialization
|
||||
df_copy = df.copy()
|
||||
for col in df_copy.columns:
|
||||
if pd.api.types.is_datetime64_any_dtype(df_copy[col]):
|
||||
df_copy[col] = df_copy[col].astype(str)
|
||||
|
||||
data = df_copy.to_dict(orient='records')
|
||||
|
||||
entry = cache.set(
|
||||
query, params, data,
|
||||
includes_current_data=includes_current_data
|
||||
)
|
||||
|
||||
if entry:
|
||||
logger.info(f"Cached {len(data)} rows (key={entry.cache_key[:16]}...)")
|
||||
return True
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache result: {e}")
|
||||
return False
|
||||
|
||||
def clear_cache(self) -> int:
|
||||
"""
|
||||
Clear all cached data.
|
||||
|
||||
Returns:
|
||||
Number of cache entries cleared
|
||||
"""
|
||||
try:
|
||||
from data_processing.cache import get_cache
|
||||
cache = get_cache()
|
||||
return cache.clear()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clear cache: {e}")
|
||||
return 0
|
||||
|
||||
def refresh_from_snowflake(
|
||||
self,
|
||||
start_date: Optional[date] = None,
|
||||
end_date: Optional[date] = None,
|
||||
trusts: Optional[list[str]] = None,
|
||||
drugs: Optional[list[str]] = None,
|
||||
directories: Optional[list[str]] = None,
|
||||
progress_callback: Optional[Callable[[int, int], None]] = None,
|
||||
) -> DataSourceResult:
|
||||
"""
|
||||
Force a refresh from Snowflake, bypassing cache and other sources.
|
||||
|
||||
This method specifically queries Snowflake and will fail if Snowflake
|
||||
is not available or not configured.
|
||||
|
||||
Args:
|
||||
start_date: Optional start date for filtering
|
||||
end_date: Optional end date for filtering
|
||||
trusts: Optional list of trust names
|
||||
drugs: Optional list of drug names
|
||||
directories: Optional list of directories
|
||||
progress_callback: Optional progress callback
|
||||
|
||||
Returns:
|
||||
DataSourceResult from Snowflake
|
||||
|
||||
Raises:
|
||||
ValueError: If Snowflake is not available or query fails
|
||||
"""
|
||||
from data_processing.snowflake_connector import (
|
||||
is_snowflake_available,
|
||||
is_snowflake_configured,
|
||||
)
|
||||
|
||||
if not is_snowflake_available():
|
||||
raise ValueError("Snowflake connector not installed")
|
||||
|
||||
if not is_snowflake_configured():
|
||||
raise ValueError("Snowflake not configured - edit config/snowflake.toml")
|
||||
|
||||
result = self._try_snowflake(
|
||||
start_date, end_date, trusts, drugs, directories, progress_callback
|
||||
)
|
||||
|
||||
if result is None:
|
||||
raise ValueError("Snowflake query failed - check logs for details")
|
||||
|
||||
# Cache the fresh result
|
||||
if self._cache_enabled:
|
||||
self._cache_result(
|
||||
result.df,
|
||||
start_date, end_date, trusts, drugs, directories,
|
||||
includes_current_data=end_date is None or end_date >= date.today()
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# Module-level singleton and convenience functions
|
||||
_default_manager: Optional[DataSourceManager] = None
|
||||
|
||||
|
||||
def get_data_source_manager(
|
||||
cache_enabled: bool = True,
|
||||
local_file_path: Optional[Path | str] = None,
|
||||
sqlite_db_path: Optional[Path | str] = None,
|
||||
) -> DataSourceManager:
|
||||
"""
|
||||
Get a DataSourceManager instance.
|
||||
|
||||
Args:
|
||||
cache_enabled: Whether to enable caching
|
||||
local_file_path: Optional path to local CSV/Parquet file
|
||||
sqlite_db_path: Optional path to SQLite database
|
||||
|
||||
Returns:
|
||||
DataSourceManager instance
|
||||
"""
|
||||
global _default_manager
|
||||
|
||||
# If custom paths provided, create a new manager
|
||||
if local_file_path or sqlite_db_path:
|
||||
return DataSourceManager(
|
||||
cache_enabled=cache_enabled,
|
||||
local_file_path=local_file_path,
|
||||
sqlite_db_path=sqlite_db_path,
|
||||
)
|
||||
|
||||
# Otherwise use/create singleton
|
||||
if _default_manager is None:
|
||||
_default_manager = DataSourceManager(cache_enabled=cache_enabled)
|
||||
|
||||
return _default_manager
|
||||
|
||||
|
||||
def get_data(
|
||||
start_date: Optional[date] = None,
|
||||
end_date: Optional[date] = None,
|
||||
trusts: Optional[list[str]] = None,
|
||||
drugs: Optional[list[str]] = None,
|
||||
directories: Optional[list[str]] = None,
|
||||
preferred_source: Optional[str] = None,
|
||||
skip_cache: bool = False,
|
||||
) -> DataSourceResult:
|
||||
"""
|
||||
Convenience function to get data using the default manager.
|
||||
|
||||
Args:
|
||||
start_date: Optional start date for filtering
|
||||
end_date: Optional end date for filtering
|
||||
trusts: Optional list of trust names
|
||||
drugs: Optional list of drug names
|
||||
directories: Optional list of directories
|
||||
preferred_source: Optional preferred source
|
||||
skip_cache: If True, bypass cache
|
||||
|
||||
Returns:
|
||||
DataSourceResult with loaded data
|
||||
"""
|
||||
manager = get_data_source_manager()
|
||||
return manager.get_data(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
trusts=trusts,
|
||||
drugs=drugs,
|
||||
directories=directories,
|
||||
preferred_source=preferred_source,
|
||||
skip_cache=skip_cache,
|
||||
)
|
||||
|
||||
|
||||
def reset_data_source_manager() -> None:
|
||||
"""Reset the default data source manager singleton."""
|
||||
global _default_manager
|
||||
_default_manager = None
|
||||
|
||||
|
||||
# Export public API
|
||||
__all__ = [
|
||||
"DataSourceType",
|
||||
"DataSourceResult",
|
||||
"SourceStatus",
|
||||
"DataSourceManager",
|
||||
"get_data_source_manager",
|
||||
"get_data",
|
||||
"reset_data_source_manager",
|
||||
]
|
||||
@@ -0,0 +1,239 @@
|
||||
"""
|
||||
SQLite database connection management for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
Provides connection management, schema initialization, and common database operations.
|
||||
Uses context manager pattern for safe resource handling.
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Optional, Generator, Literal
|
||||
|
||||
from core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DatabaseConfig:
|
||||
"""
|
||||
Configuration for SQLite database location and connection parameters.
|
||||
|
||||
Attributes:
|
||||
db_path: Path to the SQLite database file
|
||||
timeout: Connection timeout in seconds (default: 30)
|
||||
isolation_level: Transaction isolation level (default: None for autocommit)
|
||||
"""
|
||||
|
||||
DEFAULT_DB_NAME = "pathways.db"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_path: Optional[Path] = None,
|
||||
data_dir: Optional[Path] = None,
|
||||
timeout: float = 30.0,
|
||||
isolation_level: Optional[Literal['DEFERRED', 'EXCLUSIVE', 'IMMEDIATE']] = None
|
||||
):
|
||||
"""
|
||||
Initialize database configuration.
|
||||
|
||||
Args:
|
||||
db_path: Full path to database file. If None, uses data_dir/DEFAULT_DB_NAME.
|
||||
data_dir: Directory to place database in. Defaults to ./data/
|
||||
timeout: Connection timeout in seconds.
|
||||
isolation_level: Transaction isolation level. None = autocommit.
|
||||
"""
|
||||
if db_path is not None:
|
||||
self.db_path = Path(db_path)
|
||||
elif data_dir is not None:
|
||||
self.db_path = Path(data_dir) / self.DEFAULT_DB_NAME
|
||||
else:
|
||||
self.db_path = Path("./data") / self.DEFAULT_DB_NAME
|
||||
|
||||
self.timeout = timeout
|
||||
self.isolation_level = isolation_level
|
||||
|
||||
def validate(self) -> list[str]:
|
||||
"""
|
||||
Validate database configuration.
|
||||
|
||||
Returns:
|
||||
List of error messages. Empty list means configuration is valid.
|
||||
"""
|
||||
errors = []
|
||||
|
||||
# Check parent directory exists
|
||||
parent_dir = self.db_path.parent
|
||||
if not parent_dir.exists():
|
||||
errors.append(f"Database directory does not exist: {parent_dir}")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""
|
||||
Manages SQLite database connections and operations.
|
||||
|
||||
Provides context manager for safe connection handling and methods
|
||||
for common database operations.
|
||||
|
||||
Usage:
|
||||
db_manager = DatabaseManager()
|
||||
|
||||
# Using context manager (recommended)
|
||||
with db_manager.get_connection() as conn:
|
||||
cursor = conn.execute("SELECT * FROM ref_drug_names")
|
||||
results = cursor.fetchall()
|
||||
|
||||
# Or get a managed connection for longer operations
|
||||
conn = db_manager.connect()
|
||||
try:
|
||||
# ... do work ...
|
||||
finally:
|
||||
conn.close()
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[DatabaseConfig] = None):
|
||||
"""
|
||||
Initialize the database manager.
|
||||
|
||||
Args:
|
||||
config: Database configuration. If None, uses default configuration.
|
||||
"""
|
||||
self.config = config or DatabaseConfig()
|
||||
self._connection: Optional[sqlite3.Connection] = None
|
||||
|
||||
@property
|
||||
def db_path(self) -> Path:
|
||||
"""Path to the SQLite database file."""
|
||||
return self.config.db_path
|
||||
|
||||
@property
|
||||
def exists(self) -> bool:
|
||||
"""Check if the database file exists."""
|
||||
return self.db_path.exists()
|
||||
|
||||
def connect(self) -> sqlite3.Connection:
|
||||
"""
|
||||
Create a new database connection.
|
||||
|
||||
Returns:
|
||||
sqlite3.Connection: New database connection.
|
||||
|
||||
Note:
|
||||
The caller is responsible for closing the connection.
|
||||
Consider using get_connection() context manager instead.
|
||||
"""
|
||||
conn = sqlite3.connect(
|
||||
str(self.db_path),
|
||||
timeout=self.config.timeout,
|
||||
isolation_level=self.config.isolation_level
|
||||
)
|
||||
# Enable foreign key support
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
# Return rows as sqlite3.Row for dict-like access
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self) -> Generator[sqlite3.Connection, None, None]:
|
||||
"""
|
||||
Context manager for database connections.
|
||||
|
||||
Yields:
|
||||
sqlite3.Connection: Database connection.
|
||||
|
||||
Example:
|
||||
with db_manager.get_connection() as conn:
|
||||
conn.execute("INSERT INTO table VALUES (?)", (value,))
|
||||
conn.commit()
|
||||
"""
|
||||
conn = self.connect()
|
||||
try:
|
||||
yield conn
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
@contextmanager
|
||||
def get_transaction(self) -> Generator[sqlite3.Connection, None, None]:
|
||||
"""
|
||||
Context manager for transactional operations.
|
||||
|
||||
Automatically commits on success, rolls back on exception.
|
||||
|
||||
Yields:
|
||||
sqlite3.Connection: Database connection in transaction mode.
|
||||
|
||||
Example:
|
||||
with db_manager.get_transaction() as conn:
|
||||
conn.execute("INSERT INTO table VALUES (?)", (value1,))
|
||||
conn.execute("INSERT INTO other_table VALUES (?)", (value2,))
|
||||
# Auto-commits if no exception
|
||||
"""
|
||||
conn = sqlite3.connect(
|
||||
str(self.db_path),
|
||||
timeout=self.config.timeout,
|
||||
isolation_level="DEFERRED" # Explicit transaction mode
|
||||
)
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
yield conn
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def execute_script(self, sql_script: str) -> None:
|
||||
"""
|
||||
Execute a SQL script (multiple statements).
|
||||
|
||||
Args:
|
||||
sql_script: SQL script containing one or more statements.
|
||||
"""
|
||||
with self.get_connection() as conn:
|
||||
conn.executescript(sql_script)
|
||||
logger.info("Executed SQL script successfully")
|
||||
|
||||
def table_exists(self, table_name: str) -> bool:
|
||||
"""
|
||||
Check if a table exists in the database.
|
||||
|
||||
Args:
|
||||
table_name: Name of the table to check.
|
||||
|
||||
Returns:
|
||||
True if the table exists, False otherwise.
|
||||
"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name=?",
|
||||
(table_name,)
|
||||
)
|
||||
return cursor.fetchone() is not None
|
||||
|
||||
def get_table_count(self, table_name: str) -> int:
|
||||
"""
|
||||
Get the row count for a table.
|
||||
|
||||
Args:
|
||||
table_name: Name of the table.
|
||||
|
||||
Returns:
|
||||
Number of rows in the table.
|
||||
"""
|
||||
with self.get_connection() as conn:
|
||||
# Use parameterized table name via string formatting (safe since we control table_name)
|
||||
cursor = conn.execute(f"SELECT COUNT(*) FROM {table_name}")
|
||||
result = cursor.fetchone()
|
||||
return result[0] if result else 0
|
||||
|
||||
|
||||
# Default instance for application-wide use
|
||||
default_db_config = DatabaseConfig()
|
||||
default_db_manager = DatabaseManager(default_db_config)
|
||||
@@ -0,0 +1,581 @@
|
||||
"""
|
||||
Diagnosis lookup module for NHS Patient Pathway Analysis.
|
||||
|
||||
Provides functions to validate patient indications by checking GP diagnosis records
|
||||
against SNOMED cluster codes. Uses the drug-to-cluster mapping from
|
||||
drug_indication_clusters.csv and queries Snowflake for SNOMED codes and GP records.
|
||||
|
||||
Key workflow:
|
||||
1. Get drug's valid indication clusters from local mapping
|
||||
2. Get all SNOMED codes for those clusters from Snowflake
|
||||
3. Check if patient has any of those SNOMED codes in GP records
|
||||
4. Report indication validation status
|
||||
|
||||
IMPORTANT: HCD activity data indication codes are UNRELIABLE. This module uses
|
||||
GP/Primary Care data (PrimaryCareClinicalCoding) as the authoritative source.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import date, datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Callable, Any, cast
|
||||
import csv
|
||||
|
||||
from core.logging_config import get_logger
|
||||
from data_processing.database import DatabaseManager, default_db_manager
|
||||
from data_processing.snowflake_connector import (
|
||||
SnowflakeConnector,
|
||||
get_connector,
|
||||
is_snowflake_available,
|
||||
is_snowflake_configured,
|
||||
SNOWFLAKE_AVAILABLE,
|
||||
)
|
||||
from data_processing.cache import get_cache, is_cache_enabled
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClusterSnomedCodes:
|
||||
"""SNOMED codes for a clinical coding cluster."""
|
||||
cluster_id: str
|
||||
cluster_description: str
|
||||
snomed_codes: list[str] = field(default_factory=list)
|
||||
snomed_descriptions: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def code_count(self) -> int:
|
||||
return len(self.snomed_codes)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndicationValidationResult:
|
||||
"""Result of validating a patient's indication for a drug."""
|
||||
patient_pseudonym: str
|
||||
drug_name: str
|
||||
has_valid_indication: bool
|
||||
matched_cluster_id: Optional[str] = None
|
||||
matched_snomed_code: Optional[str] = None
|
||||
matched_snomed_description: Optional[str] = None
|
||||
checked_clusters: list[str] = field(default_factory=list)
|
||||
total_codes_checked: int = 0
|
||||
source: str = "GP_SNOMED" # GP_SNOMED | NONE
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DrugIndicationMatchRate:
|
||||
"""Match rate statistics for a drug's indication validation."""
|
||||
drug_name: str
|
||||
total_patients: int
|
||||
patients_with_indication: int
|
||||
patients_without_indication: int
|
||||
match_rate: float # 0.0 to 1.0
|
||||
clusters_checked: list[str] = field(default_factory=list)
|
||||
sample_unmatched: list[str] = field(default_factory=list) # Sample patient IDs
|
||||
|
||||
|
||||
def get_drug_clusters(
|
||||
drug_name: str,
|
||||
db_manager: Optional[DatabaseManager] = None
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Get all SNOMED cluster mappings for a drug from local SQLite.
|
||||
|
||||
Args:
|
||||
drug_name: Drug name to look up (case-insensitive)
|
||||
db_manager: Optional DatabaseManager (defaults to default_db_manager)
|
||||
|
||||
Returns:
|
||||
List of dicts with keys: drug_name, indication, cluster_id,
|
||||
cluster_description, nice_ta_reference
|
||||
"""
|
||||
if db_manager is None:
|
||||
db_manager = default_db_manager
|
||||
|
||||
query = """
|
||||
SELECT drug_name, indication, cluster_id, cluster_description, nice_ta_reference
|
||||
FROM ref_drug_indication_clusters
|
||||
WHERE UPPER(drug_name) = UPPER(?)
|
||||
ORDER BY indication, cluster_id
|
||||
"""
|
||||
|
||||
try:
|
||||
with db_manager.get_connection() as conn:
|
||||
cursor = conn.execute(query, (drug_name,))
|
||||
rows = cursor.fetchall()
|
||||
|
||||
results = []
|
||||
for row in rows:
|
||||
results.append({
|
||||
"drug_name": row["drug_name"],
|
||||
"indication": row["indication"],
|
||||
"cluster_id": row["cluster_id"],
|
||||
"cluster_description": row["cluster_description"],
|
||||
"nice_ta_reference": row["nice_ta_reference"],
|
||||
})
|
||||
|
||||
logger.debug(f"Found {len(results)} cluster mappings for drug '{drug_name}'")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting clusters for drug '{drug_name}': {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_drug_cluster_ids(
|
||||
drug_name: str,
|
||||
db_manager: Optional[DatabaseManager] = None
|
||||
) -> list[str]:
|
||||
"""
|
||||
Get unique cluster IDs for a drug.
|
||||
|
||||
Args:
|
||||
drug_name: Drug name to look up
|
||||
db_manager: Optional DatabaseManager
|
||||
|
||||
Returns:
|
||||
List of unique cluster IDs
|
||||
"""
|
||||
clusters = get_drug_clusters(drug_name, db_manager)
|
||||
return list(set(c["cluster_id"] for c in clusters))
|
||||
|
||||
|
||||
def get_cluster_snomed_codes(
|
||||
cluster_id: str,
|
||||
connector: Optional[SnowflakeConnector] = None,
|
||||
use_cache: bool = True,
|
||||
) -> ClusterSnomedCodes:
|
||||
"""
|
||||
Get all SNOMED codes for a cluster from Snowflake.
|
||||
|
||||
Queries the ClinicalCodingClusterSnomedCodes table to get all SNOMED codes
|
||||
that belong to the specified cluster.
|
||||
|
||||
Args:
|
||||
cluster_id: Cluster ID to look up (e.g., 'RARTH_COD', 'PSORIASIS_COD')
|
||||
connector: Optional SnowflakeConnector (defaults to singleton)
|
||||
use_cache: Whether to use cached results (default True)
|
||||
|
||||
Returns:
|
||||
ClusterSnomedCodes with list of SNOMED codes and descriptions
|
||||
"""
|
||||
if not SNOWFLAKE_AVAILABLE:
|
||||
logger.warning("Snowflake connector not available")
|
||||
return ClusterSnomedCodes(cluster_id=cluster_id, cluster_description="")
|
||||
|
||||
if not is_snowflake_configured():
|
||||
logger.warning("Snowflake not configured - cannot get cluster codes")
|
||||
return ClusterSnomedCodes(cluster_id=cluster_id, cluster_description="")
|
||||
|
||||
# Check cache first
|
||||
cache_key = f"cluster_snomed_{cluster_id}"
|
||||
if use_cache and is_cache_enabled():
|
||||
cache = get_cache()
|
||||
cached = cache.get(cache_key)
|
||||
if cached is not None and len(cached) > 0:
|
||||
logger.debug(f"Using cached SNOMED codes for cluster '{cluster_id}'")
|
||||
cached_dict = cached[0] # First element is our data dict
|
||||
return ClusterSnomedCodes(
|
||||
cluster_id=cluster_id,
|
||||
cluster_description=str(cached_dict.get("description", "")),
|
||||
snomed_codes=list(cached_dict.get("codes", [])),
|
||||
snomed_descriptions=dict(cached_dict.get("descriptions", {})),
|
||||
)
|
||||
|
||||
if connector is None:
|
||||
connector = get_connector()
|
||||
|
||||
query = '''
|
||||
SELECT DISTINCT
|
||||
"Cluster_ID",
|
||||
"Cluster_Description",
|
||||
"SNOMEDCode",
|
||||
"SNOMEDDescription"
|
||||
FROM DATA_HUB.PHM."ClinicalCodingClusterSnomedCodes"
|
||||
WHERE "Cluster_ID" = %s
|
||||
ORDER BY "SNOMEDCode"
|
||||
'''
|
||||
|
||||
try:
|
||||
results = connector.execute_dict(query, (cluster_id,))
|
||||
|
||||
if not results:
|
||||
logger.warning(f"No SNOMED codes found for cluster '{cluster_id}'")
|
||||
return ClusterSnomedCodes(cluster_id=cluster_id, cluster_description="")
|
||||
|
||||
codes = []
|
||||
descriptions = {}
|
||||
description = results[0].get("Cluster_Description", "") if results else ""
|
||||
|
||||
for row in results:
|
||||
code = row.get("SNOMEDCode")
|
||||
if code:
|
||||
codes.append(code)
|
||||
descriptions[code] = row.get("SNOMEDDescription", "")
|
||||
|
||||
logger.info(f"Found {len(codes)} SNOMED codes for cluster '{cluster_id}'")
|
||||
|
||||
# Cache the results (using query-based cache with fake params)
|
||||
if use_cache and is_cache_enabled():
|
||||
cache = get_cache()
|
||||
cache_data = [{
|
||||
"description": description,
|
||||
"codes": codes,
|
||||
"descriptions": descriptions,
|
||||
}]
|
||||
cache.set(cache_key, None, cache_data) # type: ignore[arg-type]
|
||||
|
||||
return ClusterSnomedCodes(
|
||||
cluster_id=cluster_id,
|
||||
cluster_description=description,
|
||||
snomed_codes=codes,
|
||||
snomed_descriptions=descriptions,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting SNOMED codes for cluster '{cluster_id}': {e}")
|
||||
return ClusterSnomedCodes(cluster_id=cluster_id, cluster_description="")
|
||||
|
||||
|
||||
def patient_has_indication(
|
||||
patient_pseudonym: str,
|
||||
cluster_ids: list[str],
|
||||
connector: Optional[SnowflakeConnector] = None,
|
||||
before_date: Optional[date] = None,
|
||||
) -> tuple[bool, Optional[str], Optional[str], Optional[str]]:
|
||||
"""
|
||||
Check if a patient has any SNOMED codes from the specified clusters in GP records.
|
||||
|
||||
Args:
|
||||
patient_pseudonym: Patient's pseudonymised NHS number
|
||||
cluster_ids: List of cluster IDs to check against
|
||||
connector: Optional SnowflakeConnector
|
||||
before_date: Optional date - only check diagnoses before this date
|
||||
|
||||
Returns:
|
||||
Tuple of (has_indication, matched_cluster_id, matched_snomed_code, matched_description)
|
||||
"""
|
||||
if not SNOWFLAKE_AVAILABLE or not is_snowflake_configured():
|
||||
return False, None, None, None
|
||||
|
||||
if not cluster_ids:
|
||||
return False, None, None, None
|
||||
|
||||
if connector is None:
|
||||
connector = get_connector()
|
||||
|
||||
# Build placeholders for cluster IDs
|
||||
placeholders = ", ".join(["%s"] * len(cluster_ids))
|
||||
|
||||
# Query to check if patient has any matching SNOMED code
|
||||
query = f'''
|
||||
SELECT
|
||||
pc."SNOMEDCode",
|
||||
cc."Cluster_ID",
|
||||
cc."SNOMEDDescription"
|
||||
FROM DATA_HUB.PHM."PrimaryCareClinicalCoding" pc
|
||||
INNER JOIN DATA_HUB.PHM."ClinicalCodingClusterSnomedCodes" cc
|
||||
ON pc."SNOMEDCode" = cc."SNOMEDCode"
|
||||
WHERE pc."PatientPseudonym" = %s
|
||||
AND cc."Cluster_ID" IN ({placeholders})
|
||||
'''
|
||||
|
||||
params = [patient_pseudonym] + cluster_ids
|
||||
|
||||
if before_date:
|
||||
query += ' AND pc."EventDateTime" < %s'
|
||||
params.append(before_date.isoformat())
|
||||
|
||||
query += ' LIMIT 1'
|
||||
|
||||
try:
|
||||
results = connector.execute_dict(query, tuple(params))
|
||||
|
||||
if results:
|
||||
row = results[0]
|
||||
return (
|
||||
True,
|
||||
row.get("Cluster_ID"),
|
||||
row.get("SNOMEDCode"),
|
||||
row.get("SNOMEDDescription"),
|
||||
)
|
||||
|
||||
return False, None, None, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking indication for patient '{patient_pseudonym}': {e}")
|
||||
return False, None, None, None
|
||||
|
||||
|
||||
def validate_indication(
|
||||
patient_pseudonym: str,
|
||||
drug_name: str,
|
||||
connector: Optional[SnowflakeConnector] = None,
|
||||
db_manager: Optional[DatabaseManager] = None,
|
||||
before_date: Optional[date] = None,
|
||||
) -> IndicationValidationResult:
|
||||
"""
|
||||
Validate that a patient has an appropriate indication for a drug.
|
||||
|
||||
Full validation workflow:
|
||||
1. Get drug's valid indication clusters from local mapping
|
||||
2. Check if patient has any matching SNOMED codes in GP records
|
||||
3. Return detailed validation result
|
||||
|
||||
Args:
|
||||
patient_pseudonym: Patient's pseudonymised NHS number
|
||||
drug_name: Drug name to validate indication for
|
||||
connector: Optional SnowflakeConnector
|
||||
db_manager: Optional DatabaseManager
|
||||
before_date: Optional date - only check diagnoses before this date
|
||||
|
||||
Returns:
|
||||
IndicationValidationResult with validation details
|
||||
"""
|
||||
result = IndicationValidationResult(
|
||||
patient_pseudonym=patient_pseudonym,
|
||||
drug_name=drug_name,
|
||||
has_valid_indication=False,
|
||||
)
|
||||
|
||||
# Step 1: Get drug's cluster mappings
|
||||
cluster_ids = get_drug_cluster_ids(drug_name, db_manager)
|
||||
|
||||
if not cluster_ids:
|
||||
result.error_message = f"No cluster mappings found for drug '{drug_name}'"
|
||||
result.source = "NONE"
|
||||
return result
|
||||
|
||||
result.checked_clusters = cluster_ids
|
||||
|
||||
# Step 2: Check Snowflake availability
|
||||
if not SNOWFLAKE_AVAILABLE:
|
||||
result.error_message = "Snowflake connector not installed"
|
||||
result.source = "NONE"
|
||||
return result
|
||||
|
||||
if not is_snowflake_configured():
|
||||
result.error_message = "Snowflake not configured"
|
||||
result.source = "NONE"
|
||||
return result
|
||||
|
||||
# Step 3: Check patient GP records
|
||||
has_indication, matched_cluster, matched_code, matched_desc = patient_has_indication(
|
||||
patient_pseudonym=patient_pseudonym,
|
||||
cluster_ids=cluster_ids,
|
||||
connector=connector,
|
||||
before_date=before_date,
|
||||
)
|
||||
|
||||
result.has_valid_indication = has_indication
|
||||
result.matched_cluster_id = matched_cluster
|
||||
result.matched_snomed_code = matched_code
|
||||
result.matched_snomed_description = matched_desc
|
||||
result.source = "GP_SNOMED" if has_indication else "NONE"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_indication_match_rate(
|
||||
drug_name: str,
|
||||
patient_pseudonyms: list[str],
|
||||
connector: Optional[SnowflakeConnector] = None,
|
||||
db_manager: Optional[DatabaseManager] = None,
|
||||
sample_unmatched_count: int = 10,
|
||||
) -> DrugIndicationMatchRate:
|
||||
"""
|
||||
Calculate indication match rate for a drug across a list of patients.
|
||||
|
||||
Args:
|
||||
drug_name: Drug name to check
|
||||
patient_pseudonyms: List of patient pseudonymised NHS numbers
|
||||
connector: Optional SnowflakeConnector
|
||||
db_manager: Optional DatabaseManager
|
||||
sample_unmatched_count: Number of unmatched patient IDs to include in sample
|
||||
|
||||
Returns:
|
||||
DrugIndicationMatchRate with match statistics
|
||||
"""
|
||||
if connector is None and SNOWFLAKE_AVAILABLE and is_snowflake_configured():
|
||||
connector = get_connector()
|
||||
|
||||
cluster_ids = get_drug_cluster_ids(drug_name, db_manager)
|
||||
|
||||
total = len(patient_pseudonyms)
|
||||
matched = 0
|
||||
unmatched = 0
|
||||
sample_unmatched: list[str] = []
|
||||
|
||||
if not cluster_ids:
|
||||
logger.warning(f"No cluster mappings for drug '{drug_name}' - all patients will be unmatched")
|
||||
return DrugIndicationMatchRate(
|
||||
drug_name=drug_name,
|
||||
total_patients=total,
|
||||
patients_with_indication=0,
|
||||
patients_without_indication=total,
|
||||
match_rate=0.0,
|
||||
clusters_checked=[],
|
||||
sample_unmatched=patient_pseudonyms[:sample_unmatched_count],
|
||||
)
|
||||
|
||||
for i, pseudonym in enumerate(patient_pseudonyms):
|
||||
if i > 0 and i % 100 == 0:
|
||||
logger.info(f"Validating indications: {i}/{total} ({100*i/total:.1f}%)")
|
||||
|
||||
has_indication, _, _, _ = patient_has_indication(
|
||||
patient_pseudonym=pseudonym,
|
||||
cluster_ids=cluster_ids,
|
||||
connector=connector,
|
||||
)
|
||||
|
||||
if has_indication:
|
||||
matched += 1
|
||||
else:
|
||||
unmatched += 1
|
||||
if len(sample_unmatched) < sample_unmatched_count:
|
||||
sample_unmatched.append(pseudonym)
|
||||
|
||||
match_rate = matched / total if total > 0 else 0.0
|
||||
|
||||
logger.info(f"Indication match rate for '{drug_name}': {100*match_rate:.1f}% ({matched}/{total})")
|
||||
|
||||
return DrugIndicationMatchRate(
|
||||
drug_name=drug_name,
|
||||
total_patients=total,
|
||||
patients_with_indication=matched,
|
||||
patients_without_indication=unmatched,
|
||||
match_rate=match_rate,
|
||||
clusters_checked=cluster_ids,
|
||||
sample_unmatched=sample_unmatched,
|
||||
)
|
||||
|
||||
|
||||
def batch_validate_indications(
|
||||
patient_drug_pairs: list[tuple[str, str]],
|
||||
connector: Optional[SnowflakeConnector] = None,
|
||||
db_manager: Optional[DatabaseManager] = None,
|
||||
progress_callback: Optional[Callable[[int, int], None]] = None,
|
||||
) -> list[IndicationValidationResult]:
|
||||
"""
|
||||
Validate indications for multiple patient-drug pairs efficiently.
|
||||
|
||||
Args:
|
||||
patient_drug_pairs: List of (patient_pseudonym, drug_name) tuples
|
||||
connector: Optional SnowflakeConnector
|
||||
db_manager: Optional DatabaseManager
|
||||
progress_callback: Optional callback(current, total) for progress updates
|
||||
|
||||
Returns:
|
||||
List of IndicationValidationResult for each pair
|
||||
"""
|
||||
results = []
|
||||
total = len(patient_drug_pairs)
|
||||
|
||||
# Cache cluster lookups by drug
|
||||
drug_clusters_cache = {}
|
||||
|
||||
for i, (pseudonym, drug_name) in enumerate(patient_drug_pairs):
|
||||
if progress_callback:
|
||||
progress_callback(i + 1, total)
|
||||
|
||||
# Get clusters from cache or lookup
|
||||
drug_upper = drug_name.upper()
|
||||
if drug_upper not in drug_clusters_cache:
|
||||
drug_clusters_cache[drug_upper] = get_drug_cluster_ids(drug_name, db_manager)
|
||||
|
||||
cluster_ids = drug_clusters_cache[drug_upper]
|
||||
|
||||
if not cluster_ids:
|
||||
results.append(IndicationValidationResult(
|
||||
patient_pseudonym=pseudonym,
|
||||
drug_name=drug_name,
|
||||
has_valid_indication=False,
|
||||
source="NONE",
|
||||
error_message=f"No cluster mappings for drug '{drug_name}'",
|
||||
))
|
||||
continue
|
||||
|
||||
# Check patient indication
|
||||
has_indication, matched_cluster, matched_code, matched_desc = patient_has_indication(
|
||||
patient_pseudonym=pseudonym,
|
||||
cluster_ids=cluster_ids,
|
||||
connector=connector,
|
||||
)
|
||||
|
||||
results.append(IndicationValidationResult(
|
||||
patient_pseudonym=pseudonym,
|
||||
drug_name=drug_name,
|
||||
has_valid_indication=has_indication,
|
||||
matched_cluster_id=matched_cluster,
|
||||
matched_snomed_code=matched_code,
|
||||
matched_snomed_description=matched_desc,
|
||||
checked_clusters=cluster_ids,
|
||||
source="GP_SNOMED" if has_indication else "NONE",
|
||||
))
|
||||
|
||||
matched_count = sum(1 for r in results if r.has_valid_indication)
|
||||
logger.info(f"Batch validation complete: {matched_count}/{total} ({100*matched_count/total:.1f}%) with valid indications")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def get_available_clusters(
|
||||
connector: Optional[SnowflakeConnector] = None,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Get list of all available SNOMED clusters from Snowflake.
|
||||
|
||||
Returns:
|
||||
List of dicts with cluster_id, cluster_description, code_count
|
||||
"""
|
||||
if not SNOWFLAKE_AVAILABLE or not is_snowflake_configured():
|
||||
logger.warning("Snowflake not available - cannot list clusters")
|
||||
return []
|
||||
|
||||
if connector is None:
|
||||
connector = get_connector()
|
||||
|
||||
query = '''
|
||||
SELECT
|
||||
"Cluster_ID",
|
||||
"Cluster_Description",
|
||||
COUNT(DISTINCT "SNOMEDCode") as code_count
|
||||
FROM DATA_HUB.PHM."ClinicalCodingClusterSnomedCodes"
|
||||
GROUP BY "Cluster_ID", "Cluster_Description"
|
||||
ORDER BY "Cluster_ID"
|
||||
'''
|
||||
|
||||
try:
|
||||
results = connector.execute_dict(query)
|
||||
|
||||
clusters = []
|
||||
for row in results:
|
||||
clusters.append({
|
||||
"cluster_id": row.get("Cluster_ID"),
|
||||
"cluster_description": row.get("Cluster_Description"),
|
||||
"code_count": row.get("code_count", 0),
|
||||
})
|
||||
|
||||
logger.info(f"Found {len(clusters)} available SNOMED clusters")
|
||||
return clusters
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting available clusters: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# Export public API
|
||||
__all__ = [
|
||||
"ClusterSnomedCodes",
|
||||
"IndicationValidationResult",
|
||||
"DrugIndicationMatchRate",
|
||||
"get_drug_clusters",
|
||||
"get_drug_cluster_ids",
|
||||
"get_cluster_snomed_codes",
|
||||
"patient_has_indication",
|
||||
"validate_indication",
|
||||
"get_indication_match_rate",
|
||||
"batch_validate_indications",
|
||||
"get_available_clusters",
|
||||
]
|
||||
@@ -0,0 +1,399 @@
|
||||
"""
|
||||
Data loader abstractions for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
Provides a unified interface for loading patient intervention data from:
|
||||
- CSV/Parquet files (current behavior)
|
||||
- SQLite database (new, faster approach)
|
||||
- Snowflake (future, direct from warehouse)
|
||||
|
||||
The DataLoader ABC defines the contract for all loader implementations.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import date
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from core import PathConfig, default_paths
|
||||
from core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadResult:
|
||||
"""Result of a data load operation.
|
||||
|
||||
Attributes:
|
||||
df: The loaded DataFrame with processed patient intervention data
|
||||
source: Description of the data source (e.g., "csv:/path/to/file.csv", "sqlite:fact_interventions")
|
||||
row_count: Number of rows loaded
|
||||
columns: List of column names in the DataFrame
|
||||
load_time_seconds: Time taken to load the data
|
||||
"""
|
||||
df: pd.DataFrame
|
||||
source: str
|
||||
row_count: int
|
||||
columns: list[str] = field(default_factory=list)
|
||||
load_time_seconds: float = 0.0
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.columns:
|
||||
self.columns = list(self.df.columns)
|
||||
|
||||
|
||||
# Expected columns in a processed DataFrame
|
||||
# These are the columns that generate_graph() expects to receive
|
||||
REQUIRED_COLUMNS = [
|
||||
"UPID", # Unique Patient ID (Provider Code prefix + PersonKey)
|
||||
"Drug Name", # Standardized drug name
|
||||
"Intervention Date", # Date of intervention
|
||||
"Price Actual", # Cost of intervention
|
||||
"OrganisationName", # NHS Trust name
|
||||
"Directory", # Medical specialty/directory
|
||||
"Provider Code", # NHS provider code
|
||||
"PersonKey", # Patient identifier within provider
|
||||
]
|
||||
|
||||
# Additional columns that are useful but not strictly required
|
||||
OPTIONAL_COLUMNS = [
|
||||
"UPIDTreatment", # UPID + Drug Name combo (created by generate_graph)
|
||||
"Treatment Function Code", # NHS treatment function code
|
||||
"Additional Detail 1",
|
||||
"Additional Detail 2",
|
||||
"Additional Detail 3",
|
||||
"Additional Detail 4",
|
||||
"Additional Detail 5",
|
||||
]
|
||||
|
||||
|
||||
class DataLoader(ABC):
|
||||
"""Abstract base class for data loaders.
|
||||
|
||||
All data loaders must implement the load() method which returns
|
||||
a DataFrame ready for use by generate_graph().
|
||||
|
||||
The returned DataFrame must contain REQUIRED_COLUMNS at minimum.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def load(self) -> LoadResult:
|
||||
"""Load and process patient intervention data.
|
||||
|
||||
Returns:
|
||||
LoadResult containing the processed DataFrame and metadata.
|
||||
The DataFrame must contain all REQUIRED_COLUMNS.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the data source doesn't exist
|
||||
ValueError: If the data is malformed or missing required columns
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_source(self) -> tuple[bool, str]:
|
||||
"""Check if the data source is valid and accessible.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, message).
|
||||
If is_valid is False, message explains the issue.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def source_description(self) -> str:
|
||||
"""Human-readable description of the data source."""
|
||||
pass
|
||||
|
||||
def validate_dataframe(self, df: pd.DataFrame) -> tuple[bool, list[str]]:
|
||||
"""Validate that a DataFrame has all required columns.
|
||||
|
||||
Args:
|
||||
df: DataFrame to validate
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, missing_columns).
|
||||
If is_valid is False, missing_columns lists what's missing.
|
||||
"""
|
||||
missing = [col for col in REQUIRED_COLUMNS if col not in df.columns]
|
||||
return len(missing) == 0, missing
|
||||
|
||||
|
||||
class FileDataLoader(DataLoader):
|
||||
"""Loads data from CSV or Parquet files.
|
||||
|
||||
This replicates the current behavior of dashboard_gui.main():
|
||||
1. Read CSV or Parquet file
|
||||
2. Apply patient_id() transformation
|
||||
3. Convert dates
|
||||
4. Apply drug_names() standardization
|
||||
5. Clean organization names
|
||||
6. Apply department_identification()
|
||||
|
||||
Args:
|
||||
file_path: Path to the CSV or Parquet file
|
||||
paths: PathConfig for reference data file locations (uses default_paths if None)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: Path | str,
|
||||
paths: Optional[PathConfig] = None,
|
||||
):
|
||||
self.file_path = Path(file_path)
|
||||
self.paths = paths or default_paths
|
||||
|
||||
def validate_source(self) -> tuple[bool, str]:
|
||||
"""Check if the file exists and has a supported extension."""
|
||||
if not self.file_path.exists():
|
||||
return False, f"File not found: {self.file_path}"
|
||||
|
||||
ext = self.file_path.suffix.lower()
|
||||
if ext not in ('.csv', '.parquet'):
|
||||
return False, f"Unsupported file type: {ext}. Must be .csv or .parquet"
|
||||
|
||||
return True, "OK"
|
||||
|
||||
@property
|
||||
def source_description(self) -> str:
|
||||
return f"file:{self.file_path}"
|
||||
|
||||
def load(self) -> LoadResult:
|
||||
"""Load and process data from CSV or Parquet file.
|
||||
|
||||
Applies the same transformation pipeline as the original
|
||||
dashboard_gui.main() function.
|
||||
"""
|
||||
import time
|
||||
from tools import data
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Validate source before loading
|
||||
is_valid, msg = self.validate_source()
|
||||
if not is_valid:
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
# Read file based on extension
|
||||
ext = self.file_path.suffix.lower()
|
||||
logger.info(f"Reading {ext} file: {self.file_path}")
|
||||
|
||||
if ext == '.csv':
|
||||
df_raw = pd.read_csv(self.file_path, low_memory=False)
|
||||
else: # .parquet
|
||||
df_raw = pd.read_parquet(self.file_path)
|
||||
|
||||
logger.info(f"File read successfully. {len(df_raw)} rows.")
|
||||
|
||||
# Apply transformations (same as dashboard_gui.main())
|
||||
df = data.patient_id(df_raw)
|
||||
logger.info("Patient ID processing complete.")
|
||||
|
||||
df['Intervention Date'] = pd.to_datetime(df['Intervention Date'], format="%Y-%m-%d")
|
||||
logger.info("Date conversion complete.")
|
||||
|
||||
# Preserve original drug name before standardization (for SQLite storage)
|
||||
df['Drug Name Raw'] = df['Drug Name'].copy()
|
||||
|
||||
df = data.drug_names(df, self.paths)
|
||||
logger.info("Drug name processing complete.")
|
||||
|
||||
df['OrganisationName'] = df['OrganisationName'].str.replace(',', '')
|
||||
logger.info("Organisation name cleaning complete.")
|
||||
|
||||
df = data.department_identification(df, self.paths)
|
||||
logger.info("Department identification complete.")
|
||||
|
||||
# Validate result
|
||||
is_valid, missing = self.validate_dataframe(df)
|
||||
if not is_valid:
|
||||
raise ValueError(f"Processed DataFrame missing required columns: {missing}")
|
||||
|
||||
load_time = time.time() - start_time
|
||||
logger.info(f"Data loading complete. {len(df)} rows in {load_time:.2f}s")
|
||||
|
||||
return LoadResult(
|
||||
df=df,
|
||||
source=self.source_description,
|
||||
row_count=len(df),
|
||||
load_time_seconds=load_time,
|
||||
)
|
||||
|
||||
|
||||
class SQLiteDataLoader(DataLoader):
|
||||
"""Loads data from SQLite fact_interventions table.
|
||||
|
||||
This provides faster loading by reading pre-processed data from SQLite
|
||||
instead of re-processing CSV files each time.
|
||||
|
||||
The SQLite database must have been populated by the migration scripts.
|
||||
|
||||
Args:
|
||||
db_path: Path to the SQLite database (uses default if None)
|
||||
date_range: Optional tuple of (start_date, end_date) to filter data
|
||||
trusts: Optional list of trust names to filter
|
||||
drugs: Optional list of drug names to filter
|
||||
directories: Optional list of directories to filter
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_path: Optional[Path | str] = None,
|
||||
date_range: Optional[tuple[date, date]] = None,
|
||||
trusts: Optional[list[str]] = None,
|
||||
drugs: Optional[list[str]] = None,
|
||||
directories: Optional[list[str]] = None,
|
||||
):
|
||||
from data_processing.database import default_db_config
|
||||
|
||||
self.db_path = Path(db_path) if db_path else Path(default_db_config.db_path)
|
||||
self.date_range = date_range
|
||||
self.trusts = trusts
|
||||
self.drugs = drugs
|
||||
self.directories = directories
|
||||
|
||||
def validate_source(self) -> tuple[bool, str]:
|
||||
"""Check if the database exists and has the fact_interventions table."""
|
||||
if not self.db_path.exists():
|
||||
return False, f"Database not found: {self.db_path}"
|
||||
|
||||
# Check if fact_interventions table exists
|
||||
from data_processing.database import DatabaseManager, DatabaseConfig
|
||||
|
||||
config = DatabaseConfig(db_path=self.db_path)
|
||||
manager = DatabaseManager(config)
|
||||
|
||||
if not manager.table_exists("fact_interventions"):
|
||||
return False, "fact_interventions table not found in database"
|
||||
|
||||
count = manager.get_table_count("fact_interventions")
|
||||
if count == 0:
|
||||
return False, "fact_interventions table is empty"
|
||||
|
||||
return True, f"OK ({count:,} rows available)"
|
||||
|
||||
@property
|
||||
def source_description(self) -> str:
|
||||
return f"sqlite:{self.db_path}"
|
||||
|
||||
def load(self) -> LoadResult:
|
||||
"""Load data from SQLite fact_interventions table.
|
||||
|
||||
Maps SQLite column names to the expected DataFrame column names.
|
||||
Applies optional filters for date range, trusts, drugs, directories.
|
||||
"""
|
||||
import time
|
||||
from data_processing.database import DatabaseManager, DatabaseConfig
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Validate source
|
||||
is_valid, msg = self.validate_source()
|
||||
if not is_valid:
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
logger.info(f"Loading data from SQLite: {self.db_path}")
|
||||
|
||||
# Build query with optional filters
|
||||
query = """
|
||||
SELECT
|
||||
upid AS "UPID",
|
||||
provider_code AS "Provider Code",
|
||||
person_key AS "PersonKey",
|
||||
drug_name_std AS "Drug Name",
|
||||
intervention_date AS "Intervention Date",
|
||||
price_actual AS "Price Actual",
|
||||
org_name AS "OrganisationName",
|
||||
directory AS "Directory",
|
||||
treatment_function_code AS "Treatment Function Code",
|
||||
additional_detail_1 AS "Additional Detail 1",
|
||||
additional_detail_2 AS "Additional Detail 2",
|
||||
additional_detail_3 AS "Additional Detail 3",
|
||||
additional_detail_4 AS "Additional Detail 4",
|
||||
additional_detail_5 AS "Additional Detail 5"
|
||||
FROM fact_interventions
|
||||
WHERE 1=1
|
||||
"""
|
||||
params = []
|
||||
|
||||
if self.date_range:
|
||||
start, end = self.date_range
|
||||
query += " AND intervention_date >= ? AND intervention_date < ?"
|
||||
params.extend([str(start), str(end)])
|
||||
|
||||
if self.trusts:
|
||||
placeholders = ','.join('?' * len(self.trusts))
|
||||
query += f" AND org_name IN ({placeholders})"
|
||||
params.extend(self.trusts)
|
||||
|
||||
if self.drugs:
|
||||
placeholders = ','.join('?' * len(self.drugs))
|
||||
query += f" AND drug_name_std IN ({placeholders})"
|
||||
params.extend(self.drugs)
|
||||
|
||||
if self.directories:
|
||||
placeholders = ','.join('?' * len(self.directories))
|
||||
query += f" AND directory IN ({placeholders})"
|
||||
params.extend(self.directories)
|
||||
|
||||
# Execute query
|
||||
config = DatabaseConfig(db_path=self.db_path)
|
||||
manager = DatabaseManager(config)
|
||||
|
||||
with manager.get_connection() as conn:
|
||||
df = pd.read_sql_query(query, conn, params=params)
|
||||
|
||||
# Convert intervention_date to datetime
|
||||
df['Intervention Date'] = pd.to_datetime(df['Intervention Date'])
|
||||
|
||||
logger.info(f"Loaded {len(df)} rows from SQLite")
|
||||
|
||||
# Validate result
|
||||
is_valid, missing = self.validate_dataframe(df)
|
||||
if not is_valid:
|
||||
raise ValueError(f"SQLite data missing required columns: {missing}")
|
||||
|
||||
load_time = time.time() - start_time
|
||||
logger.info(f"SQLite data loading complete. {len(df)} rows in {load_time:.2f}s")
|
||||
|
||||
return LoadResult(
|
||||
df=df,
|
||||
source=self.source_description,
|
||||
row_count=len(df),
|
||||
load_time_seconds=load_time,
|
||||
)
|
||||
|
||||
|
||||
def get_loader(
|
||||
source: str | Path,
|
||||
paths: Optional[PathConfig] = None,
|
||||
**kwargs
|
||||
) -> DataLoader:
|
||||
"""Factory function to create the appropriate DataLoader.
|
||||
|
||||
Args:
|
||||
source: Either a file path (CSV/Parquet) or "sqlite" for database
|
||||
paths: PathConfig for reference data (used by FileDataLoader)
|
||||
**kwargs: Additional arguments passed to the loader constructor
|
||||
|
||||
Returns:
|
||||
Appropriate DataLoader instance
|
||||
|
||||
Examples:
|
||||
>>> loader = get_loader("data/activity.csv")
|
||||
>>> loader = get_loader("data/activity.parquet")
|
||||
>>> loader = get_loader("sqlite")
|
||||
>>> loader = get_loader("sqlite", date_range=(date(2024, 1, 1), date(2024, 12, 31)))
|
||||
"""
|
||||
source_str = str(source).lower()
|
||||
|
||||
if source_str == "sqlite":
|
||||
return SQLiteDataLoader(**kwargs)
|
||||
|
||||
# Assume it's a file path
|
||||
path = Path(source)
|
||||
return FileDataLoader(file_path=path, paths=paths)
|
||||
@@ -0,0 +1,593 @@
|
||||
"""
|
||||
Database migration script for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
Provides functions to initialize the SQLite database schema and CLI interface
|
||||
for running migrations from the command line.
|
||||
|
||||
Usage:
|
||||
# Initialize database (creates all tables)
|
||||
python -m data_processing.migrate
|
||||
|
||||
# Drop existing tables and reinitialize
|
||||
python -m data_processing.migrate --drop-existing
|
||||
|
||||
# Show current database status
|
||||
python -m data_processing.migrate --status
|
||||
|
||||
# Migrate all reference data from CSV files
|
||||
python -m data_processing.migrate --reference-data
|
||||
|
||||
# Migrate reference data with verification
|
||||
python -m data_processing.migrate --reference-data --verify
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from core.logging_config import setup_logging, get_logger
|
||||
from data_processing.database import DatabaseManager, DatabaseConfig
|
||||
from core import PathConfig, default_paths
|
||||
from data_processing.schema import (
|
||||
create_all_tables,
|
||||
drop_all_tables,
|
||||
verify_all_tables_exist,
|
||||
get_all_table_counts,
|
||||
)
|
||||
from data_processing.reference_data import (
|
||||
MigrationResult,
|
||||
migrate_drug_names,
|
||||
migrate_organizations,
|
||||
migrate_directories,
|
||||
migrate_drug_directory_map,
|
||||
migrate_drug_indication_clusters,
|
||||
verify_drug_names_migration,
|
||||
verify_organizations_migration,
|
||||
verify_directories_migration,
|
||||
verify_drug_directory_map_migration,
|
||||
verify_drug_indication_clusters_migration,
|
||||
)
|
||||
from data_processing.patient_data import (
|
||||
load_patient_data,
|
||||
refresh_patient_treatment_summary,
|
||||
get_patient_data_stats,
|
||||
verify_mv_consistency,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def initialize_database(
|
||||
db_manager: Optional[DatabaseManager] = None,
|
||||
drop_existing: bool = False,
|
||||
confirm_drop: bool = True
|
||||
) -> bool:
|
||||
"""
|
||||
Initialize the database with all required tables.
|
||||
|
||||
Creates all tables defined in the schema (reference tables, fact tables,
|
||||
materialized views, and file tracking tables). Uses IF NOT EXISTS so
|
||||
safe to run multiple times.
|
||||
|
||||
Args:
|
||||
db_manager: DatabaseManager instance. Uses default if not provided.
|
||||
drop_existing: If True, drops all existing tables before creating.
|
||||
confirm_drop: If True and drop_existing=True, prompts for confirmation.
|
||||
Set to False for non-interactive use.
|
||||
|
||||
Returns:
|
||||
True if initialization succeeded, False otherwise.
|
||||
"""
|
||||
if db_manager is None:
|
||||
db_manager = DatabaseManager()
|
||||
|
||||
logger.info(f"Initializing database at: {db_manager.db_path}")
|
||||
|
||||
# Handle drop existing with confirmation
|
||||
if drop_existing:
|
||||
if confirm_drop:
|
||||
print(f"\nWARNING: This will delete ALL data from the database:")
|
||||
print(f" {db_manager.db_path}\n")
|
||||
response = input("Are you sure you want to continue? (yes/no): ")
|
||||
if response.lower() not in ("yes", "y"):
|
||||
print("Operation cancelled.")
|
||||
return False
|
||||
|
||||
if db_manager.exists:
|
||||
logger.warning("Dropping existing tables...")
|
||||
with db_manager.get_connection() as conn:
|
||||
drop_all_tables(conn)
|
||||
conn.commit()
|
||||
logger.info("Existing tables dropped")
|
||||
else:
|
||||
logger.info("Database does not exist yet, nothing to drop")
|
||||
|
||||
# Create all tables
|
||||
try:
|
||||
with db_manager.get_transaction() as conn:
|
||||
create_all_tables(conn)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create tables: {e}")
|
||||
return False
|
||||
|
||||
# Verify all tables were created
|
||||
with db_manager.get_connection() as conn:
|
||||
missing = verify_all_tables_exist(conn)
|
||||
|
||||
if missing:
|
||||
logger.error(f"Table creation failed. Missing tables: {missing}")
|
||||
return False
|
||||
|
||||
logger.info("All tables created successfully")
|
||||
return True
|
||||
|
||||
|
||||
def migrate_all_reference_data(
|
||||
db_manager: Optional[DatabaseManager] = None,
|
||||
paths: Optional[PathConfig] = None,
|
||||
verify: bool = False
|
||||
) -> tuple[bool, list[MigrationResult]]:
|
||||
"""
|
||||
Run all reference data migrations from CSV files to SQLite tables.
|
||||
|
||||
Migrations are run in order:
|
||||
1. Drug names (drugnames.csv → ref_drug_names)
|
||||
2. Organizations (org_codes.csv → ref_organizations)
|
||||
3. Directories (directory_list.csv → ref_directories)
|
||||
4. Drug-directory mappings (drug_directory_list.csv → ref_drug_directory_map)
|
||||
|
||||
Args:
|
||||
db_manager: DatabaseManager instance. Uses default if not provided.
|
||||
paths: PathConfig instance for locating CSV files. Uses default if not provided.
|
||||
verify: If True, runs verification after each migration.
|
||||
|
||||
Returns:
|
||||
Tuple of (all_success: bool, results: list of MigrationResult)
|
||||
"""
|
||||
if db_manager is None:
|
||||
db_manager = DatabaseManager()
|
||||
if paths is None:
|
||||
paths = default_paths
|
||||
|
||||
results: list[MigrationResult] = []
|
||||
all_success = True
|
||||
|
||||
# Define migrations in order
|
||||
# Note: drug_indication_clusters uses a different signature (csv_path instead of paths)
|
||||
migrations = [
|
||||
("Drug names", migrate_drug_names, verify_drug_names_migration if verify else None, True),
|
||||
("Organizations", migrate_organizations, verify_organizations_migration if verify else None, True),
|
||||
("Directories", migrate_directories, verify_directories_migration if verify else None, True),
|
||||
("Drug-directory map", migrate_drug_directory_map, verify_drug_directory_map_migration if verify else None, True),
|
||||
("Drug indication clusters", migrate_drug_indication_clusters, verify_drug_indication_clusters_migration if verify else None, False),
|
||||
]
|
||||
|
||||
logger.info(f"Starting reference data migrations ({len(migrations)} tables)")
|
||||
|
||||
for name, migrate_fn, verify_fn, uses_paths in migrations:
|
||||
logger.info(f"Migrating: {name}...")
|
||||
|
||||
# Run migration (some use paths parameter, some use csv_path)
|
||||
if uses_paths:
|
||||
result = migrate_fn(db_manager=db_manager, paths=paths) # type: ignore[operator]
|
||||
else:
|
||||
# Drug indication clusters uses csv_path instead of paths
|
||||
result = migrate_fn(db_manager=db_manager) # type: ignore[operator]
|
||||
results.append(result)
|
||||
|
||||
if not result.success:
|
||||
logger.error(f"Migration failed: {name} - {result.error_message}")
|
||||
all_success = False
|
||||
continue
|
||||
|
||||
logger.info(f" {result}")
|
||||
|
||||
# Run verification if requested
|
||||
if verify_fn is not None:
|
||||
logger.info(f" Verifying {name}...")
|
||||
if uses_paths:
|
||||
verified, verify_msg = verify_fn(db_manager=db_manager, paths=paths) # type: ignore[call-arg]
|
||||
else:
|
||||
verified, verify_msg = verify_fn(db_manager=db_manager) # type: ignore[call-arg]
|
||||
if verified:
|
||||
logger.info(f" OK: {verify_msg}")
|
||||
else:
|
||||
logger.error(f" FAILED: Verification failed: {verify_msg}")
|
||||
all_success = False
|
||||
|
||||
# Summary
|
||||
successful = sum(1 for r in results if r.success)
|
||||
logger.info(f"Reference data migrations complete: {successful}/{len(results)} succeeded")
|
||||
|
||||
return all_success, results
|
||||
|
||||
|
||||
def print_migration_summary(results: list[MigrationResult]) -> None:
|
||||
"""Print a summary of migration results to stdout."""
|
||||
print("\n=== Reference Data Migration Summary ===\n")
|
||||
|
||||
for result in results:
|
||||
status = "[OK]" if result.success else "[FAILED]"
|
||||
print(f"{status} {result.table_name}")
|
||||
if result.success:
|
||||
print(f" Read: {result.rows_read}, Inserted: {result.rows_inserted}, Skipped: {result.rows_skipped}")
|
||||
else:
|
||||
print(f" Error: {result.error_message}")
|
||||
|
||||
successful = sum(1 for r in results if r.success)
|
||||
print(f"\nTotal: {successful}/{len(results)} migrations succeeded")
|
||||
print()
|
||||
|
||||
|
||||
def create_progress_reporter(description: str = "Loading", width: int = 40):
|
||||
"""
|
||||
Create a progress callback that prints a progress bar to stdout.
|
||||
|
||||
Args:
|
||||
description: Label to show before the progress bar.
|
||||
width: Width of the progress bar in characters.
|
||||
|
||||
Returns:
|
||||
Callback function(current, total) that prints progress.
|
||||
"""
|
||||
last_percent = [-1] # Use list to allow mutation in closure
|
||||
|
||||
def report_progress(current: int, total: int) -> None:
|
||||
"""Print a progress bar showing current/total progress."""
|
||||
if total == 0:
|
||||
percent = 100
|
||||
else:
|
||||
percent = int(100 * current / total)
|
||||
|
||||
# Only update display when percentage changes (avoid excessive output)
|
||||
if percent == last_percent[0]:
|
||||
return
|
||||
last_percent[0] = percent
|
||||
|
||||
filled = int(width * current / total) if total > 0 else width
|
||||
bar = "=" * filled + "-" * (width - filled)
|
||||
|
||||
# Use carriage return to overwrite the line
|
||||
sys.stdout.write(f"\r{description}: [{bar}] {percent:3d}% ({current:,}/{total:,})")
|
||||
sys.stdout.flush()
|
||||
|
||||
# Print newline when complete
|
||||
if current >= total:
|
||||
print()
|
||||
|
||||
return report_progress
|
||||
|
||||
|
||||
def load_patient_data_cli(
|
||||
file_path: Path,
|
||||
db_manager: Optional[DatabaseManager] = None,
|
||||
paths: Optional[PathConfig] = None,
|
||||
force: bool = False,
|
||||
refresh_mv: bool = True
|
||||
) -> bool:
|
||||
"""
|
||||
Load patient data from file with CLI progress reporting.
|
||||
|
||||
Args:
|
||||
file_path: Path to CSV or Parquet file.
|
||||
db_manager: DatabaseManager instance. Uses default if not provided.
|
||||
paths: PathConfig for reference data. Uses default if not provided.
|
||||
force: If True, re-process even if file hash matches.
|
||||
refresh_mv: If True, refresh the materialized view after loading.
|
||||
|
||||
Returns:
|
||||
True if loading succeeded, False otherwise.
|
||||
"""
|
||||
if db_manager is None:
|
||||
db_manager = DatabaseManager()
|
||||
if paths is None:
|
||||
paths = default_paths
|
||||
|
||||
print(f"\n=== Loading Patient Data ===\n")
|
||||
print(f"File: {file_path}")
|
||||
|
||||
# Check file exists
|
||||
if not file_path.exists():
|
||||
print(f"ERROR: File not found: {file_path}")
|
||||
return False
|
||||
|
||||
# Calculate and display file info
|
||||
file_size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||
print(f"Size: {file_size_mb:.1f} MB")
|
||||
print()
|
||||
|
||||
# Create progress callback
|
||||
progress_callback = create_progress_reporter("Loading rows", width=40)
|
||||
|
||||
# Load the data
|
||||
result = load_patient_data(
|
||||
file_path=file_path,
|
||||
db_manager=db_manager,
|
||||
paths=paths,
|
||||
batch_size=5000,
|
||||
force=force,
|
||||
progress_callback=progress_callback
|
||||
)
|
||||
|
||||
# Print result
|
||||
print()
|
||||
if result.was_already_processed:
|
||||
print("File already processed (same hash). Skipping.")
|
||||
print(f"Use --force to re-process.")
|
||||
elif result.success:
|
||||
print(f"Loaded {result.rows_inserted:,} rows in {result.load_time_seconds:.1f}s")
|
||||
if result.rows_skipped > 0:
|
||||
print(f"Skipped {result.rows_skipped:,} rows (missing UPID or date)")
|
||||
else:
|
||||
print(f"FAILED: {result.error_message}")
|
||||
return False
|
||||
|
||||
# Refresh materialized view if requested
|
||||
if refresh_mv and result.success and not result.was_already_processed:
|
||||
print()
|
||||
print("Refreshing materialized view...")
|
||||
mv_progress = create_progress_reporter("Processing patients", width=40)
|
||||
mv_result = refresh_patient_treatment_summary(
|
||||
db_manager=db_manager,
|
||||
progress_callback=mv_progress
|
||||
)
|
||||
|
||||
if mv_result.success:
|
||||
print(f"MV refreshed: {mv_result.patients_processed:,} patients in {mv_result.refresh_time_seconds:.1f}s")
|
||||
|
||||
# Verify consistency
|
||||
consistent, msg = verify_mv_consistency(db_manager)
|
||||
if consistent:
|
||||
print(f"MV verification: OK")
|
||||
else:
|
||||
print(f"MV verification: FAILED - {msg}")
|
||||
else:
|
||||
print(f"MV refresh FAILED: {mv_result.error_message}")
|
||||
|
||||
# Print summary statistics
|
||||
print()
|
||||
print("=== Patient Data Summary ===")
|
||||
stats = get_patient_data_stats(db_manager)
|
||||
print(f" Total rows: {stats['total_rows']:,}")
|
||||
print(f" Unique patients: {stats['unique_patients']:,}")
|
||||
print(f" Unique drugs: {stats['unique_drugs']:,}")
|
||||
print(f" Unique organizations: {stats['unique_organizations']:,}")
|
||||
if stats['date_range'][0] and stats['date_range'][1]:
|
||||
print(f" Date range: {stats['date_range'][0]} to {stats['date_range'][1]}")
|
||||
print()
|
||||
|
||||
return result.success
|
||||
|
||||
|
||||
def get_database_status(db_manager: Optional[DatabaseManager] = None) -> dict:
|
||||
"""
|
||||
Get the current status of the database.
|
||||
|
||||
Returns:
|
||||
Dictionary with database status information:
|
||||
- exists: Whether the database file exists
|
||||
- path: Path to the database file
|
||||
- size_bytes: Size of database file (if exists)
|
||||
- tables: Dictionary of table names to row counts
|
||||
- missing_tables: List of expected tables that don't exist
|
||||
"""
|
||||
if db_manager is None:
|
||||
db_manager = DatabaseManager()
|
||||
|
||||
status = {
|
||||
"exists": db_manager.exists,
|
||||
"path": str(db_manager.db_path),
|
||||
"size_bytes": None,
|
||||
"tables": {},
|
||||
"missing_tables": [],
|
||||
}
|
||||
|
||||
if db_manager.exists:
|
||||
status["size_bytes"] = db_manager.db_path.stat().st_size
|
||||
|
||||
with db_manager.get_connection() as conn:
|
||||
status["missing_tables"] = verify_all_tables_exist(conn)
|
||||
|
||||
# Get counts for existing tables
|
||||
try:
|
||||
status["tables"] = get_all_table_counts(conn)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not get table counts: {e}")
|
||||
|
||||
return status
|
||||
|
||||
|
||||
def print_database_status(db_manager: Optional[DatabaseManager] = None) -> None:
|
||||
"""Print database status to stdout in a human-readable format."""
|
||||
status = get_database_status(db_manager)
|
||||
|
||||
print("\n=== Database Status ===\n")
|
||||
print(f"Path: {status['path']}")
|
||||
print(f"Exists: {status['exists']}")
|
||||
|
||||
if status["exists"]:
|
||||
size_kb = (status["size_bytes"] or 0) / 1024
|
||||
print(f"Size: {size_kb:.1f} KB")
|
||||
|
||||
if status["missing_tables"]:
|
||||
print(f"\nMissing tables: {', '.join(status['missing_tables'])}")
|
||||
else:
|
||||
print("\nAll expected tables exist.")
|
||||
|
||||
if status["tables"]:
|
||||
print("\nTable row counts:")
|
||||
for table, count in sorted(status["tables"].items()):
|
||||
print(f" {table}: {count:,} rows")
|
||||
else:
|
||||
print("\nDatabase does not exist. Run migration to create it.")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def main():
|
||||
"""CLI entry point for database migration."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Initialize NHS Pathways Analysis SQLite database schema",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
python -m data_processing.migrate # Initialize database
|
||||
python -m data_processing.migrate --status # Show database status
|
||||
python -m data_processing.migrate --drop-existing # Reset database
|
||||
python -m data_processing.migrate --reference-data # Migrate reference data
|
||||
python -m data_processing.migrate --reference-data --verify # With verification
|
||||
python -m data_processing.migrate --load-patient-data data.parquet # Load patient data
|
||||
python -m data_processing.migrate --load-patient-data data.csv --force # Force reload
|
||||
python -m data_processing.migrate --db-path ./data/test.db # Custom path
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--status",
|
||||
action="store_true",
|
||||
help="Show current database status and exit"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--drop-existing",
|
||||
action="store_true",
|
||||
help="Drop all existing tables before creating (WARNING: deletes data)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reference-data",
|
||||
action="store_true",
|
||||
help="Migrate all reference data from CSV files to SQLite tables"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verify",
|
||||
action="store_true",
|
||||
help="Verify migrated data matches CSV sources (use with --reference-data)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--db-path",
|
||||
type=Path,
|
||||
help="Path to database file (default: ./data/pathways.db)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--yes", "-y",
|
||||
action="store_true",
|
||||
help="Skip confirmation prompts (for non-interactive use)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose", "-v",
|
||||
action="store_true",
|
||||
help="Enable verbose logging"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load-patient-data",
|
||||
type=Path,
|
||||
metavar="FILE",
|
||||
help="Load patient data from CSV or Parquet file with progress reporting"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force",
|
||||
action="store_true",
|
||||
help="Force re-processing even if file hash matches (use with --load-patient-data)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-refresh-mv",
|
||||
action="store_true",
|
||||
help="Skip materialized view refresh after loading (use with --load-patient-data)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set up logging
|
||||
log_level = "DEBUG" if args.verbose else "INFO"
|
||||
setup_logging(level=log_level, simple_console=True)
|
||||
|
||||
# Create database manager with optional custom path
|
||||
if args.db_path:
|
||||
config = DatabaseConfig(db_path=args.db_path)
|
||||
db_manager = DatabaseManager(config)
|
||||
else:
|
||||
db_manager = DatabaseManager()
|
||||
|
||||
# Handle --status
|
||||
if args.status:
|
||||
print_database_status(db_manager)
|
||||
return 0
|
||||
|
||||
# Validate configuration
|
||||
config_errors = db_manager.config.validate()
|
||||
if config_errors:
|
||||
for error in config_errors:
|
||||
logger.error(error)
|
||||
return 1
|
||||
|
||||
# Handle --reference-data (migrate reference data from CSV to SQLite)
|
||||
if args.reference_data:
|
||||
# Ensure database exists with tables first
|
||||
if not db_manager.exists:
|
||||
print("Database does not exist. Initializing schema first...")
|
||||
success = initialize_database(db_manager=db_manager)
|
||||
if not success:
|
||||
print("\nDatabase initialization failed. Check logs for details.")
|
||||
return 1
|
||||
|
||||
# Run reference data migrations
|
||||
success, results = migrate_all_reference_data(
|
||||
db_manager=db_manager,
|
||||
paths=default_paths,
|
||||
verify=args.verify
|
||||
)
|
||||
|
||||
print_migration_summary(results)
|
||||
print_database_status(db_manager)
|
||||
|
||||
if success:
|
||||
print("Reference data migration completed successfully.")
|
||||
return 0
|
||||
else:
|
||||
print("Reference data migration completed with errors. Check logs for details.")
|
||||
return 1
|
||||
|
||||
# Handle --load-patient-data (load patient data from CSV/Parquet)
|
||||
if args.load_patient_data:
|
||||
# Ensure database exists with tables first
|
||||
if not db_manager.exists:
|
||||
print("Database does not exist. Initializing schema first...")
|
||||
success = initialize_database(db_manager=db_manager)
|
||||
if not success:
|
||||
print("\nDatabase initialization failed. Check logs for details.")
|
||||
return 1
|
||||
|
||||
# Load patient data with progress reporting
|
||||
success = load_patient_data_cli(
|
||||
file_path=args.load_patient_data,
|
||||
db_manager=db_manager,
|
||||
paths=default_paths,
|
||||
force=args.force,
|
||||
refresh_mv=not args.no_refresh_mv
|
||||
)
|
||||
|
||||
if success:
|
||||
print("Patient data load completed successfully.")
|
||||
return 0
|
||||
else:
|
||||
print("Patient data load failed. Check logs for details.")
|
||||
return 1
|
||||
|
||||
# Run schema migration (default behavior)
|
||||
success = initialize_database(
|
||||
db_manager=db_manager,
|
||||
drop_existing=args.drop_existing,
|
||||
confirm_drop=not args.yes
|
||||
)
|
||||
|
||||
if success:
|
||||
print("\nDatabase initialized successfully.")
|
||||
print_database_status(db_manager)
|
||||
return 0
|
||||
else:
|
||||
print("\nDatabase initialization failed. Check logs for details.")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -0,0 +1,890 @@
|
||||
"""
|
||||
Patient data migration functions for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
Provides functions to load patient intervention data from CSV/Parquet files
|
||||
into the SQLite fact_interventions table. Supports:
|
||||
- Batch processing for large files
|
||||
- File hash tracking for incremental updates
|
||||
- Progress reporting during loading
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from core import PathConfig, default_paths
|
||||
from core.logging_config import get_logger
|
||||
from data_processing.database import DatabaseManager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PatientDataLoadResult:
|
||||
"""Results from a patient data load operation."""
|
||||
file_path: str
|
||||
file_hash: str
|
||||
rows_read: int
|
||||
rows_inserted: int
|
||||
rows_skipped: int
|
||||
success: bool
|
||||
error_message: Optional[str] = None
|
||||
load_time_seconds: float = 0.0
|
||||
was_already_processed: bool = False
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.was_already_processed:
|
||||
return f"{self.file_path}: Already processed (same hash)"
|
||||
elif self.success:
|
||||
return (
|
||||
f"{self.file_path}: Loaded {self.rows_inserted:,} rows "
|
||||
f"in {self.load_time_seconds:.1f}s"
|
||||
)
|
||||
else:
|
||||
return f"{self.file_path}: FAILED - {self.error_message}"
|
||||
|
||||
|
||||
def calculate_file_hash(file_path: Path) -> str:
|
||||
"""
|
||||
Calculate SHA256 hash of a file.
|
||||
|
||||
Uses chunked reading to handle large files efficiently.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file.
|
||||
|
||||
Returns:
|
||||
Hex string of SHA256 hash.
|
||||
"""
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(8192), b""):
|
||||
sha256_hash.update(chunk)
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
|
||||
def check_file_processed(
|
||||
conn: sqlite3.Connection,
|
||||
file_path: str,
|
||||
file_hash: str
|
||||
) -> tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Check if a file has already been processed with the same hash.
|
||||
|
||||
Args:
|
||||
conn: Database connection.
|
||||
file_path: Full path to the file.
|
||||
file_hash: SHA256 hash of the file.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_processed, old_hash).
|
||||
- If is_processed is True and old_hash == file_hash, file is unchanged.
|
||||
- If is_processed is True and old_hash != file_hash, file has changed.
|
||||
- If is_processed is False, file is new.
|
||||
"""
|
||||
cursor = conn.execute(
|
||||
"SELECT file_hash, status FROM processed_files WHERE file_path = ?",
|
||||
(file_path,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
|
||||
if result is None:
|
||||
return False, None
|
||||
|
||||
old_hash = result["file_hash"]
|
||||
status = result["status"]
|
||||
|
||||
# Only consider it processed if status is success and hash matches
|
||||
if status == "success" and old_hash == file_hash:
|
||||
return True, old_hash
|
||||
|
||||
return False, old_hash
|
||||
|
||||
|
||||
def record_file_processing_start(
|
||||
conn: sqlite3.Connection,
|
||||
file_path: str,
|
||||
file_hash: str,
|
||||
file_size: int,
|
||||
file_modified: datetime
|
||||
) -> None:
|
||||
"""
|
||||
Record that we're starting to process a file.
|
||||
|
||||
Args:
|
||||
conn: Database connection.
|
||||
file_path: Full path to the file.
|
||||
file_hash: SHA256 hash of the file.
|
||||
file_size: File size in bytes.
|
||||
file_modified: File modification timestamp.
|
||||
"""
|
||||
file_name = Path(file_path).name
|
||||
now = datetime.now().isoformat()
|
||||
|
||||
conn.execute("""
|
||||
INSERT INTO processed_files (
|
||||
file_path, file_name, file_hash, file_size_bytes,
|
||||
file_modified_at, status, first_processed_at, last_processed_at
|
||||
) VALUES (?, ?, ?, ?, ?, 'processing', ?, ?)
|
||||
ON CONFLICT(file_path) DO UPDATE SET
|
||||
file_hash = excluded.file_hash,
|
||||
file_size_bytes = excluded.file_size_bytes,
|
||||
file_modified_at = excluded.file_modified_at,
|
||||
status = 'processing',
|
||||
last_processed_at = excluded.last_processed_at,
|
||||
error_message = NULL
|
||||
""", (file_path, file_name, file_hash, file_size, file_modified.isoformat(), now, now))
|
||||
|
||||
|
||||
def record_file_processing_complete(
|
||||
conn: sqlite3.Connection,
|
||||
file_path: str,
|
||||
row_count: int,
|
||||
duration_seconds: float,
|
||||
success: bool,
|
||||
error_message: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
Record that file processing has completed.
|
||||
|
||||
Args:
|
||||
conn: Database connection.
|
||||
file_path: Full path to the file.
|
||||
row_count: Number of rows processed.
|
||||
duration_seconds: Time taken to process.
|
||||
success: Whether processing was successful.
|
||||
error_message: Error message if failed.
|
||||
"""
|
||||
status = "success" if success else "error"
|
||||
|
||||
conn.execute("""
|
||||
UPDATE processed_files
|
||||
SET status = ?,
|
||||
row_count = ?,
|
||||
processing_duration_seconds = ?,
|
||||
error_message = ?,
|
||||
last_processed_at = ?
|
||||
WHERE file_path = ?
|
||||
""", (status, row_count, duration_seconds, error_message, datetime.now().isoformat(), file_path))
|
||||
|
||||
|
||||
def load_dataframe_to_sqlite(
|
||||
df: pd.DataFrame,
|
||||
conn: sqlite3.Connection,
|
||||
source_file: str,
|
||||
batch_size: int = 5000,
|
||||
progress_callback: Optional[Callable[[int, int], None]] = None
|
||||
) -> int:
|
||||
"""
|
||||
Load a processed DataFrame into fact_interventions table.
|
||||
|
||||
Args:
|
||||
df: Processed DataFrame with required columns (from FileDataLoader).
|
||||
conn: Database connection.
|
||||
source_file: Source file path for tracking.
|
||||
batch_size: Number of rows to insert per batch.
|
||||
progress_callback: Optional callback(rows_inserted, total_rows) for progress updates.
|
||||
|
||||
Returns:
|
||||
Number of rows inserted.
|
||||
"""
|
||||
# Store the original drug names before processing (for rows where mapping doesn't exist)
|
||||
# The drug_names() transformation sets Drug Name to NULL when no mapping exists.
|
||||
# We need to preserve the original for those cases.
|
||||
|
||||
# Insert SQL columns - always include drug_name_raw
|
||||
insert_columns = [
|
||||
"upid", "provider_code", "person_key",
|
||||
"drug_name_raw", "drug_name_std",
|
||||
"intervention_date", "price_actual",
|
||||
"org_name", "directory",
|
||||
"treatment_function_code",
|
||||
"additional_detail_1", "additional_detail_2", "additional_detail_3",
|
||||
"additional_detail_4", "additional_detail_5",
|
||||
"source_file"
|
||||
]
|
||||
placeholders = ",".join(["?"] * len(insert_columns))
|
||||
insert_sql = f"""
|
||||
INSERT INTO fact_interventions ({",".join(insert_columns)})
|
||||
VALUES ({placeholders})
|
||||
"""
|
||||
|
||||
rows_inserted = 0
|
||||
rows_skipped = 0
|
||||
total_rows = len(df)
|
||||
|
||||
# Process in batches
|
||||
for batch_start in range(0, total_rows, batch_size):
|
||||
batch_end = min(batch_start + batch_size, total_rows)
|
||||
batch_df = df.iloc[batch_start:batch_end]
|
||||
|
||||
# Prepare batch data
|
||||
batch_data = []
|
||||
for _, row in batch_df.iterrows():
|
||||
# Skip rows missing required fields
|
||||
if pd.isna(row.get("UPID")) or pd.isna(row.get("Intervention Date")):
|
||||
rows_skipped += 1
|
||||
continue
|
||||
# Get drug names - raw and standardized
|
||||
drug_name_raw = row.get("Drug Name Raw") if "Drug Name Raw" in df.columns else None
|
||||
drug_name_std = row.get("Drug Name")
|
||||
|
||||
# If drug_name_std is NULL, use the raw drug name (uppercase)
|
||||
# This handles cases where the drug isn't in the drugnames.csv mapping
|
||||
if pd.isna(drug_name_std):
|
||||
if drug_name_raw is not None and not pd.isna(drug_name_raw):
|
||||
drug_name_std = str(drug_name_raw).upper().strip()
|
||||
else:
|
||||
drug_name_std = "UNKNOWN"
|
||||
|
||||
# Also clean up raw drug name for storage
|
||||
if drug_name_raw is not None and not pd.isna(drug_name_raw):
|
||||
drug_name_raw = str(drug_name_raw).strip()
|
||||
|
||||
# Get other values with null handling
|
||||
def get_value(col_name):
|
||||
if col_name not in df.columns:
|
||||
return None
|
||||
val = row[col_name]
|
||||
if pd.isna(val):
|
||||
return None
|
||||
elif hasattr(val, "strftime"):
|
||||
return val.strftime("%Y-%m-%d")
|
||||
return val
|
||||
|
||||
row_data = (
|
||||
get_value("UPID"),
|
||||
get_value("Provider Code"),
|
||||
get_value("PersonKey"),
|
||||
drug_name_raw,
|
||||
drug_name_std,
|
||||
get_value("Intervention Date"),
|
||||
get_value("Price Actual") or 0,
|
||||
get_value("OrganisationName"),
|
||||
get_value("Directory"),
|
||||
get_value("Treatment Function Code"),
|
||||
get_value("Additional Detail 1"),
|
||||
get_value("Additional Detail 2"),
|
||||
get_value("Additional Detail 3"),
|
||||
get_value("Additional Detail 4"),
|
||||
get_value("Additional Detail 5"),
|
||||
source_file
|
||||
)
|
||||
batch_data.append(row_data)
|
||||
|
||||
# Execute batch insert
|
||||
conn.executemany(insert_sql, batch_data)
|
||||
rows_inserted += len(batch_data)
|
||||
|
||||
# Report progress
|
||||
if progress_callback:
|
||||
progress_callback(rows_inserted, total_rows)
|
||||
|
||||
if rows_skipped > 0:
|
||||
logger.info(f"Skipped {rows_skipped:,} rows with missing UPID or Intervention Date")
|
||||
|
||||
return rows_inserted
|
||||
|
||||
|
||||
def delete_file_data(conn: sqlite3.Connection, source_file: str) -> int:
|
||||
"""
|
||||
Delete all data from a specific source file.
|
||||
|
||||
Used when re-processing a changed file.
|
||||
|
||||
Args:
|
||||
conn: Database connection.
|
||||
source_file: Source file path.
|
||||
|
||||
Returns:
|
||||
Number of rows deleted.
|
||||
"""
|
||||
cursor = conn.execute(
|
||||
"DELETE FROM fact_interventions WHERE source_file = ?",
|
||||
(source_file,)
|
||||
)
|
||||
return cursor.rowcount
|
||||
|
||||
|
||||
def load_patient_data(
|
||||
file_path: Path | str,
|
||||
db_manager: Optional[DatabaseManager] = None,
|
||||
paths: Optional[PathConfig] = None,
|
||||
batch_size: int = 5000,
|
||||
force: bool = False,
|
||||
progress_callback: Optional[Callable[[int, int], None]] = None
|
||||
) -> PatientDataLoadResult:
|
||||
"""
|
||||
Load patient data from CSV/Parquet file into fact_interventions table.
|
||||
|
||||
This is the main entry point for loading patient data. It:
|
||||
1. Calculates file hash to detect changes
|
||||
2. Checks if file was already processed (skip if unchanged)
|
||||
3. Loads and transforms data using FileDataLoader
|
||||
4. Inserts data into SQLite in batches
|
||||
5. Records processing status in processed_files table
|
||||
|
||||
Args:
|
||||
file_path: Path to CSV or Parquet file.
|
||||
db_manager: DatabaseManager instance. Uses default if not provided.
|
||||
paths: PathConfig for reference data. Uses default if not provided.
|
||||
batch_size: Number of rows to insert per batch (default: 5000).
|
||||
force: If True, re-process even if file hash matches.
|
||||
progress_callback: Optional callback(rows_inserted, total_rows) for progress.
|
||||
|
||||
Returns:
|
||||
PatientDataLoadResult with loading statistics.
|
||||
"""
|
||||
if db_manager is None:
|
||||
db_manager = DatabaseManager()
|
||||
if paths is None:
|
||||
paths = default_paths
|
||||
|
||||
file_path = Path(file_path)
|
||||
file_path_str = str(file_path.absolute())
|
||||
|
||||
logger.info(f"Starting patient data load from {file_path}")
|
||||
start_time = time.time()
|
||||
|
||||
# Check file exists
|
||||
if not file_path.exists():
|
||||
error_msg = f"File not found: {file_path}"
|
||||
logger.error(error_msg)
|
||||
return PatientDataLoadResult(
|
||||
file_path=file_path_str,
|
||||
file_hash="",
|
||||
rows_read=0,
|
||||
rows_inserted=0,
|
||||
rows_skipped=0,
|
||||
success=False,
|
||||
error_message=error_msg
|
||||
)
|
||||
|
||||
# Calculate file hash
|
||||
logger.info("Calculating file hash...")
|
||||
file_hash = calculate_file_hash(file_path)
|
||||
file_size = file_path.stat().st_size
|
||||
file_modified = datetime.fromtimestamp(file_path.stat().st_mtime)
|
||||
|
||||
logger.info(f"File hash: {file_hash[:16]}... Size: {file_size:,} bytes")
|
||||
|
||||
# Check if already processed
|
||||
if not force:
|
||||
with db_manager.get_connection() as conn:
|
||||
is_processed, old_hash = check_file_processed(conn, file_path_str, file_hash)
|
||||
if is_processed:
|
||||
logger.info(f"File already processed with same hash, skipping")
|
||||
return PatientDataLoadResult(
|
||||
file_path=file_path_str,
|
||||
file_hash=file_hash,
|
||||
rows_read=0,
|
||||
rows_inserted=0,
|
||||
rows_skipped=0,
|
||||
success=True,
|
||||
was_already_processed=True
|
||||
)
|
||||
elif old_hash is not None:
|
||||
logger.info(f"File hash changed, will re-process (old: {old_hash[:16]}...)")
|
||||
|
||||
try:
|
||||
# Use FileDataLoader to load and transform data
|
||||
from data_processing.loader import FileDataLoader
|
||||
|
||||
loader = FileDataLoader(file_path, paths)
|
||||
logger.info("Loading and transforming data...")
|
||||
result = loader.load()
|
||||
df = result.df
|
||||
rows_read = result.row_count
|
||||
|
||||
logger.info(f"Loaded {rows_read:,} rows, starting SQLite insert...")
|
||||
|
||||
# Load into SQLite
|
||||
with db_manager.get_transaction() as conn:
|
||||
# Record that we're starting
|
||||
record_file_processing_start(conn, file_path_str, file_hash, file_size, file_modified)
|
||||
|
||||
# Delete any existing data from this file (for re-processing)
|
||||
deleted = delete_file_data(conn, file_path_str)
|
||||
if deleted > 0:
|
||||
logger.info(f"Deleted {deleted:,} existing rows from previous load")
|
||||
|
||||
# Insert new data
|
||||
rows_inserted = load_dataframe_to_sqlite(
|
||||
df, conn, file_path_str, batch_size, progress_callback
|
||||
)
|
||||
|
||||
# Record success
|
||||
load_time = time.time() - start_time
|
||||
record_file_processing_complete(
|
||||
conn, file_path_str, rows_inserted, load_time, True
|
||||
)
|
||||
|
||||
logger.info(f"Successfully loaded {rows_inserted:,} rows in {load_time:.1f}s")
|
||||
|
||||
return PatientDataLoadResult(
|
||||
file_path=file_path_str,
|
||||
file_hash=file_hash,
|
||||
rows_read=rows_read,
|
||||
rows_inserted=rows_inserted,
|
||||
rows_skipped=rows_read - rows_inserted,
|
||||
success=True,
|
||||
load_time_seconds=load_time
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
load_time = time.time() - start_time
|
||||
error_msg = str(e)
|
||||
logger.error(f"Failed to load patient data: {error_msg}")
|
||||
|
||||
# Record failure
|
||||
try:
|
||||
with db_manager.get_connection() as conn:
|
||||
record_file_processing_complete(
|
||||
conn, file_path_str, 0, load_time, False, error_msg
|
||||
)
|
||||
except Exception:
|
||||
pass # Don't fail on failure to record failure
|
||||
|
||||
return PatientDataLoadResult(
|
||||
file_path=file_path_str,
|
||||
file_hash=file_hash if 'file_hash' in dir() else "",
|
||||
rows_read=0,
|
||||
rows_inserted=0,
|
||||
rows_skipped=0,
|
||||
success=False,
|
||||
error_message=error_msg,
|
||||
load_time_seconds=load_time
|
||||
)
|
||||
|
||||
|
||||
def get_patient_data_stats(db_manager: Optional[DatabaseManager] = None) -> dict:
|
||||
"""
|
||||
Get statistics about patient data in fact_interventions.
|
||||
|
||||
Returns:
|
||||
Dictionary with statistics about the loaded data.
|
||||
"""
|
||||
if db_manager is None:
|
||||
db_manager = DatabaseManager()
|
||||
|
||||
stats = {}
|
||||
|
||||
with db_manager.get_connection() as conn:
|
||||
# Total rows
|
||||
cursor = conn.execute("SELECT COUNT(*) FROM fact_interventions")
|
||||
stats["total_rows"] = cursor.fetchone()[0]
|
||||
|
||||
# Unique patients
|
||||
cursor = conn.execute("SELECT COUNT(DISTINCT upid) FROM fact_interventions")
|
||||
stats["unique_patients"] = cursor.fetchone()[0]
|
||||
|
||||
# Unique drugs
|
||||
cursor = conn.execute("SELECT COUNT(DISTINCT drug_name_std) FROM fact_interventions")
|
||||
stats["unique_drugs"] = cursor.fetchone()[0]
|
||||
|
||||
# Unique organizations
|
||||
cursor = conn.execute("SELECT COUNT(DISTINCT org_name) FROM fact_interventions")
|
||||
stats["unique_organizations"] = cursor.fetchone()[0]
|
||||
|
||||
# Date range
|
||||
cursor = conn.execute("""
|
||||
SELECT MIN(intervention_date), MAX(intervention_date)
|
||||
FROM fact_interventions
|
||||
""")
|
||||
result = cursor.fetchone()
|
||||
stats["date_range"] = (result[0], result[1]) if result else (None, None)
|
||||
|
||||
# Processed files
|
||||
cursor = conn.execute("""
|
||||
SELECT COUNT(*), SUM(row_count)
|
||||
FROM processed_files WHERE status = 'success'
|
||||
""")
|
||||
result = cursor.fetchone()
|
||||
stats["processed_files"] = result[0] if result else 0
|
||||
stats["processed_rows"] = result[1] if result and result[1] else 0
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def list_processed_files(db_manager: Optional[DatabaseManager] = None) -> list[dict]:
|
||||
"""
|
||||
List all processed files and their status.
|
||||
|
||||
Returns:
|
||||
List of dictionaries with file processing information.
|
||||
"""
|
||||
if db_manager is None:
|
||||
db_manager = DatabaseManager()
|
||||
|
||||
files = []
|
||||
|
||||
with db_manager.get_connection() as conn:
|
||||
cursor = conn.execute("""
|
||||
SELECT file_path, file_name, file_hash, file_size_bytes,
|
||||
row_count, status, error_message,
|
||||
first_processed_at, last_processed_at, processing_duration_seconds
|
||||
FROM processed_files
|
||||
ORDER BY last_processed_at DESC
|
||||
""")
|
||||
|
||||
for row in cursor.fetchall():
|
||||
files.append({
|
||||
"file_path": row["file_path"],
|
||||
"file_name": row["file_name"],
|
||||
"file_hash": row["file_hash"],
|
||||
"file_size_bytes": row["file_size_bytes"],
|
||||
"row_count": row["row_count"],
|
||||
"status": row["status"],
|
||||
"error_message": row["error_message"],
|
||||
"first_processed_at": row["first_processed_at"],
|
||||
"last_processed_at": row["last_processed_at"],
|
||||
"processing_duration_seconds": row["processing_duration_seconds"],
|
||||
})
|
||||
|
||||
return files
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Materialized View Refresh Functions
|
||||
# =============================================================================
|
||||
|
||||
@dataclass
|
||||
class MVRefreshResult:
|
||||
"""Results from refreshing the patient treatment summary materialized view."""
|
||||
patients_processed: int
|
||||
rows_inserted: int
|
||||
refresh_time_seconds: float
|
||||
success: bool
|
||||
error_message: Optional[str] = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.success:
|
||||
return (
|
||||
f"Refreshed MV: {self.patients_processed:,} patients "
|
||||
f"in {self.refresh_time_seconds:.1f}s"
|
||||
)
|
||||
else:
|
||||
return f"MV refresh FAILED: {self.error_message}"
|
||||
|
||||
|
||||
def refresh_patient_treatment_summary(
|
||||
db_manager: Optional[DatabaseManager] = None,
|
||||
progress_callback: Optional[Callable[[int, int], None]] = None
|
||||
) -> MVRefreshResult:
|
||||
"""
|
||||
Refresh the mv_patient_treatment_summary materialized view.
|
||||
|
||||
This computes per-patient aggregations from fact_interventions:
|
||||
- First/last seen dates
|
||||
- Total cost, average cost per intervention
|
||||
- Intervention count, unique drug count
|
||||
- Drug sequence (chronological, pipe-separated)
|
||||
- Drug counts, costs, and date ranges (as JSON)
|
||||
|
||||
The MV is fully rebuilt (truncate and re-insert) for simplicity.
|
||||
This typically takes 30-60 seconds for ~35,000 patients.
|
||||
|
||||
Args:
|
||||
db_manager: DatabaseManager instance. Uses default if not provided.
|
||||
progress_callback: Optional callback(patients_done, total_patients).
|
||||
|
||||
Returns:
|
||||
MVRefreshResult with refresh statistics.
|
||||
"""
|
||||
if db_manager is None:
|
||||
db_manager = DatabaseManager()
|
||||
|
||||
logger.info("Starting materialized view refresh...")
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
with db_manager.get_transaction() as conn:
|
||||
# Step 1: Get total patient count for progress reporting
|
||||
cursor = conn.execute("SELECT COUNT(DISTINCT upid) FROM fact_interventions")
|
||||
total_patients = cursor.fetchone()[0]
|
||||
logger.info(f"Processing {total_patients:,} unique patients")
|
||||
|
||||
if total_patients == 0:
|
||||
logger.warning("No patient data in fact_interventions, MV will be empty")
|
||||
return MVRefreshResult(
|
||||
patients_processed=0,
|
||||
rows_inserted=0,
|
||||
refresh_time_seconds=time.time() - start_time,
|
||||
success=True
|
||||
)
|
||||
|
||||
# Step 2: Clear existing MV data
|
||||
conn.execute("DELETE FROM mv_patient_treatment_summary")
|
||||
logger.info("Cleared existing MV data")
|
||||
|
||||
# Step 3: Compute aggregations using SQL CTEs
|
||||
# This is more efficient than processing row-by-row in Python
|
||||
refresh_sql = """
|
||||
WITH patient_aggs AS (
|
||||
-- Basic aggregations per patient
|
||||
SELECT
|
||||
upid,
|
||||
MIN(org_name) as org_name,
|
||||
MIN(directory) as directory,
|
||||
MIN(intervention_date) as first_seen_date,
|
||||
MAX(intervention_date) as last_seen_date,
|
||||
JULIANDAY(MAX(intervention_date)) - JULIANDAY(MIN(intervention_date)) as days_treated,
|
||||
SUM(price_actual) as total_cost,
|
||||
AVG(price_actual) as avg_cost_per_intervention,
|
||||
COUNT(*) as intervention_count,
|
||||
COUNT(DISTINCT drug_name_std) as unique_drug_count,
|
||||
COUNT(*) as source_row_count
|
||||
FROM fact_interventions
|
||||
GROUP BY upid
|
||||
),
|
||||
drug_sequences AS (
|
||||
-- Drug sequence per patient (chronological order, pipe-separated)
|
||||
SELECT
|
||||
upid,
|
||||
GROUP_CONCAT(drug_name_std, '|') as drug_sequence
|
||||
FROM (
|
||||
SELECT DISTINCT
|
||||
upid,
|
||||
drug_name_std,
|
||||
MIN(intervention_date) as first_date
|
||||
FROM fact_interventions
|
||||
GROUP BY upid, drug_name_std
|
||||
ORDER BY upid, first_date
|
||||
)
|
||||
GROUP BY upid
|
||||
),
|
||||
drug_counts AS (
|
||||
-- JSON object of drug counts per patient
|
||||
SELECT
|
||||
upid,
|
||||
'{' || GROUP_CONCAT('"' || drug_name_std || '": ' || cnt, ', ') || '}' as drug_counts_json
|
||||
FROM (
|
||||
SELECT
|
||||
upid,
|
||||
drug_name_std,
|
||||
COUNT(*) as cnt
|
||||
FROM fact_interventions
|
||||
GROUP BY upid, drug_name_std
|
||||
)
|
||||
GROUP BY upid
|
||||
),
|
||||
drug_costs AS (
|
||||
-- JSON object of drug costs per patient
|
||||
SELECT
|
||||
upid,
|
||||
'{' || GROUP_CONCAT('"' || drug_name_std || '": ' || ROUND(total_cost, 2), ', ') || '}' as drug_costs_json
|
||||
FROM (
|
||||
SELECT
|
||||
upid,
|
||||
drug_name_std,
|
||||
SUM(price_actual) as total_cost
|
||||
FROM fact_interventions
|
||||
GROUP BY upid, drug_name_std
|
||||
)
|
||||
GROUP BY upid
|
||||
),
|
||||
drug_dates AS (
|
||||
-- JSON object of drug date ranges per patient
|
||||
SELECT
|
||||
upid,
|
||||
'{' || GROUP_CONCAT('"' || drug_name_std || '": {"first": "' || first_date || '", "last": "' || last_date || '"}', ', ') || '}' as drug_date_ranges_json
|
||||
FROM (
|
||||
SELECT
|
||||
upid,
|
||||
drug_name_std,
|
||||
MIN(intervention_date) as first_date,
|
||||
MAX(intervention_date) as last_date
|
||||
FROM fact_interventions
|
||||
GROUP BY upid, drug_name_std
|
||||
)
|
||||
GROUP BY upid
|
||||
)
|
||||
INSERT INTO mv_patient_treatment_summary (
|
||||
upid, org_name, directory,
|
||||
first_seen_date, last_seen_date, days_treated,
|
||||
total_cost, avg_cost_per_intervention,
|
||||
intervention_count, unique_drug_count,
|
||||
drug_sequence, drug_counts_json, drug_costs_json, drug_date_ranges_json,
|
||||
source_row_count, computed_at
|
||||
)
|
||||
SELECT
|
||||
pa.upid,
|
||||
pa.org_name,
|
||||
pa.directory,
|
||||
pa.first_seen_date,
|
||||
pa.last_seen_date,
|
||||
CAST(pa.days_treated AS INTEGER),
|
||||
pa.total_cost,
|
||||
pa.avg_cost_per_intervention,
|
||||
pa.intervention_count,
|
||||
pa.unique_drug_count,
|
||||
ds.drug_sequence,
|
||||
dc.drug_counts_json,
|
||||
dco.drug_costs_json,
|
||||
dd.drug_date_ranges_json,
|
||||
pa.source_row_count,
|
||||
CURRENT_TIMESTAMP
|
||||
FROM patient_aggs pa
|
||||
LEFT JOIN drug_sequences ds ON pa.upid = ds.upid
|
||||
LEFT JOIN drug_counts dc ON pa.upid = dc.upid
|
||||
LEFT JOIN drug_costs dco ON pa.upid = dco.upid
|
||||
LEFT JOIN drug_dates dd ON pa.upid = dd.upid
|
||||
"""
|
||||
|
||||
logger.info("Executing MV refresh query...")
|
||||
conn.execute(refresh_sql)
|
||||
|
||||
# Get actual rows inserted
|
||||
cursor = conn.execute("SELECT COUNT(*) FROM mv_patient_treatment_summary")
|
||||
rows_inserted = cursor.fetchone()[0]
|
||||
|
||||
refresh_time = time.time() - start_time
|
||||
logger.info(f"MV refresh complete: {rows_inserted:,} rows in {refresh_time:.1f}s")
|
||||
|
||||
# Report progress if callback provided
|
||||
if progress_callback:
|
||||
progress_callback(rows_inserted, total_patients)
|
||||
|
||||
return MVRefreshResult(
|
||||
patients_processed=total_patients,
|
||||
rows_inserted=rows_inserted,
|
||||
refresh_time_seconds=refresh_time,
|
||||
success=True
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
refresh_time = time.time() - start_time
|
||||
error_msg = str(e)
|
||||
logger.error(f"MV refresh failed: {error_msg}")
|
||||
return MVRefreshResult(
|
||||
patients_processed=0,
|
||||
rows_inserted=0,
|
||||
refresh_time_seconds=refresh_time,
|
||||
success=False,
|
||||
error_message=error_msg
|
||||
)
|
||||
|
||||
|
||||
def get_patient_summary_stats(db_manager: Optional[DatabaseManager] = None) -> dict:
|
||||
"""
|
||||
Get statistics about the patient treatment summary MV.
|
||||
|
||||
Returns:
|
||||
Dictionary with MV statistics.
|
||||
"""
|
||||
if db_manager is None:
|
||||
db_manager = DatabaseManager()
|
||||
|
||||
stats = {}
|
||||
|
||||
with db_manager.get_connection() as conn:
|
||||
# Total rows
|
||||
cursor = conn.execute("SELECT COUNT(*) FROM mv_patient_treatment_summary")
|
||||
stats["total_patients"] = cursor.fetchone()[0]
|
||||
|
||||
if stats["total_patients"] == 0:
|
||||
return stats
|
||||
|
||||
# Aggregated statistics
|
||||
cursor = conn.execute("""
|
||||
SELECT
|
||||
SUM(total_cost) as total_cost_all,
|
||||
AVG(total_cost) as avg_cost_per_patient,
|
||||
SUM(intervention_count) as total_interventions,
|
||||
AVG(intervention_count) as avg_interventions_per_patient,
|
||||
AVG(unique_drug_count) as avg_drugs_per_patient,
|
||||
AVG(days_treated) as avg_days_treated,
|
||||
MIN(first_seen_date) as earliest_date,
|
||||
MAX(last_seen_date) as latest_date,
|
||||
MAX(computed_at) as last_refresh
|
||||
FROM mv_patient_treatment_summary
|
||||
""")
|
||||
result = cursor.fetchone()
|
||||
|
||||
stats["total_cost"] = result[0] if result[0] else 0
|
||||
stats["avg_cost_per_patient"] = result[1] if result[1] else 0
|
||||
stats["total_interventions"] = result[2] if result[2] else 0
|
||||
stats["avg_interventions_per_patient"] = result[3] if result[3] else 0
|
||||
stats["avg_drugs_per_patient"] = result[4] if result[4] else 0
|
||||
stats["avg_days_treated"] = result[5] if result[5] else 0
|
||||
stats["date_range"] = (result[6], result[7])
|
||||
stats["last_refresh"] = result[8]
|
||||
|
||||
# Unique directories in MV
|
||||
cursor = conn.execute("SELECT COUNT(DISTINCT directory) FROM mv_patient_treatment_summary")
|
||||
stats["unique_directories"] = cursor.fetchone()[0]
|
||||
|
||||
# Unique organizations in MV
|
||||
cursor = conn.execute("SELECT COUNT(DISTINCT org_name) FROM mv_patient_treatment_summary")
|
||||
stats["unique_organizations"] = cursor.fetchone()[0]
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def verify_mv_consistency(db_manager: Optional[DatabaseManager] = None) -> tuple[bool, str]:
|
||||
"""
|
||||
Verify that the MV is consistent with fact_interventions.
|
||||
|
||||
Checks that:
|
||||
- Patient counts match
|
||||
- Total cost sums match
|
||||
- Intervention counts match
|
||||
|
||||
Returns:
|
||||
Tuple of (is_consistent, message).
|
||||
"""
|
||||
if db_manager is None:
|
||||
db_manager = DatabaseManager()
|
||||
|
||||
with db_manager.get_connection() as conn:
|
||||
# Get fact table counts
|
||||
cursor = conn.execute("""
|
||||
SELECT
|
||||
COUNT(DISTINCT upid) as patients,
|
||||
SUM(price_actual) as total_cost,
|
||||
COUNT(*) as interventions
|
||||
FROM fact_interventions
|
||||
""")
|
||||
fact_row = cursor.fetchone()
|
||||
fact_patients = fact_row[0] or 0
|
||||
fact_cost = fact_row[1] or 0
|
||||
fact_interventions = fact_row[2] or 0
|
||||
|
||||
# Get MV counts
|
||||
cursor = conn.execute("""
|
||||
SELECT
|
||||
COUNT(*) as patients,
|
||||
SUM(total_cost) as total_cost,
|
||||
SUM(intervention_count) as interventions
|
||||
FROM mv_patient_treatment_summary
|
||||
""")
|
||||
mv_row = cursor.fetchone()
|
||||
mv_patients = mv_row[0] or 0
|
||||
mv_cost = mv_row[1] or 0
|
||||
mv_interventions = mv_row[2] or 0
|
||||
|
||||
# Compare
|
||||
issues = []
|
||||
|
||||
if fact_patients != mv_patients:
|
||||
issues.append(f"Patient count mismatch: fact={fact_patients:,}, mv={mv_patients:,}")
|
||||
|
||||
if mv_interventions != fact_interventions:
|
||||
issues.append(f"Intervention count mismatch: fact={fact_interventions:,}, mv={mv_interventions:,}")
|
||||
|
||||
# Allow small floating point differences in cost
|
||||
cost_diff = abs(fact_cost - mv_cost)
|
||||
if cost_diff > 0.01:
|
||||
issues.append(f"Cost mismatch: fact={fact_cost:,.2f}, mv={mv_cost:,.2f}, diff={cost_diff:.2f}")
|
||||
|
||||
if issues:
|
||||
return False, "; ".join(issues)
|
||||
|
||||
return True, f"MV consistent: {mv_patients:,} patients, {mv_interventions:,} interventions, £{mv_cost:,.2f} total"
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,665 @@
|
||||
"""
|
||||
SQLite schema definitions for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
Contains SQL strings for creating reference tables, fact tables, and indexes.
|
||||
Schema design supports:
|
||||
- Reference data from CSV files (drug names, organizations, directories)
|
||||
- Drug-directory mappings with single-valid-directory flag
|
||||
- Patient intervention facts with proper indexing
|
||||
- Cached aggregations for performance
|
||||
- File tracking for incremental updates
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
import sqlite3
|
||||
|
||||
from core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Reference Table Schemas
|
||||
# =============================================================================
|
||||
|
||||
REF_DRUG_NAMES_SCHEMA = """
|
||||
-- Mapping from raw drug names (as they appear in source data) to standardized names
|
||||
-- Source: data/drugnames.csv
|
||||
CREATE TABLE IF NOT EXISTS ref_drug_names (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
raw_name TEXT NOT NULL UNIQUE,
|
||||
standard_name TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- Index for fast lookups during data transformation
|
||||
CREATE INDEX IF NOT EXISTS idx_ref_drug_names_raw ON ref_drug_names(raw_name);
|
||||
CREATE INDEX IF NOT EXISTS idx_ref_drug_names_standard ON ref_drug_names(standard_name);
|
||||
"""
|
||||
|
||||
REF_ORGANIZATIONS_SCHEMA = """
|
||||
-- NHS organization codes and names
|
||||
-- Source: data/org_codes.csv
|
||||
CREATE TABLE IF NOT EXISTS ref_organizations (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
org_code TEXT NOT NULL UNIQUE,
|
||||
org_name TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- Index for fast lookups by organization code
|
||||
CREATE INDEX IF NOT EXISTS idx_ref_organizations_code ON ref_organizations(org_code);
|
||||
"""
|
||||
|
||||
REF_DIRECTORIES_SCHEMA = """
|
||||
-- Medical directories/specialties
|
||||
-- Source: data/directory_list.csv
|
||||
CREATE TABLE IF NOT EXISTS ref_directories (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
directory_name TEXT NOT NULL UNIQUE,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- Index for fast lookups by directory name
|
||||
CREATE INDEX IF NOT EXISTS idx_ref_directories_name ON ref_directories(directory_name);
|
||||
"""
|
||||
|
||||
REF_DRUG_DIRECTORY_MAP_SCHEMA = """
|
||||
-- Mapping from drug names to valid directories
|
||||
-- Source: data/drug_directory_list.csv
|
||||
-- A drug may map to multiple directories (one row per drug-directory pair)
|
||||
-- The is_single_valid flag indicates drugs with exactly ONE valid directory,
|
||||
-- which enables automatic directory assignment in department_identification()
|
||||
CREATE TABLE IF NOT EXISTS ref_drug_directory_map (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
drug_name TEXT NOT NULL,
|
||||
directory_name TEXT NOT NULL,
|
||||
is_single_valid BOOLEAN NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(drug_name, directory_name)
|
||||
);
|
||||
|
||||
-- Index for looking up directories by drug name (most common access pattern)
|
||||
CREATE INDEX IF NOT EXISTS idx_ref_drug_directory_map_drug ON ref_drug_directory_map(drug_name);
|
||||
|
||||
-- Index for reverse lookup (find drugs by directory)
|
||||
CREATE INDEX IF NOT EXISTS idx_ref_drug_directory_map_directory ON ref_drug_directory_map(directory_name);
|
||||
|
||||
-- Index for quick filtering of single-valid drugs
|
||||
CREATE INDEX IF NOT EXISTS idx_ref_drug_directory_map_single ON ref_drug_directory_map(is_single_valid);
|
||||
"""
|
||||
|
||||
REF_DRUG_INDICATION_CLUSTERS_SCHEMA = """
|
||||
-- Mapping from drugs to SNOMED clusters for indication validation
|
||||
-- Source: data/drug_indication_clusters.csv
|
||||
-- Used to validate that patients have appropriate GP diagnoses for their prescribed drugs
|
||||
-- A drug may map to multiple clusters (one row per drug-indication-cluster combination)
|
||||
CREATE TABLE IF NOT EXISTS ref_drug_indication_clusters (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
drug_name TEXT NOT NULL,
|
||||
indication TEXT NOT NULL,
|
||||
cluster_id TEXT NOT NULL,
|
||||
cluster_description TEXT,
|
||||
nice_ta_reference TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(drug_name, indication, cluster_id)
|
||||
);
|
||||
|
||||
-- Index for looking up clusters by drug name (most common access pattern)
|
||||
CREATE INDEX IF NOT EXISTS idx_ref_drug_indication_clusters_drug ON ref_drug_indication_clusters(drug_name);
|
||||
|
||||
-- Index for looking up drugs by cluster (for finding all drugs treating a condition)
|
||||
CREATE INDEX IF NOT EXISTS idx_ref_drug_indication_clusters_cluster ON ref_drug_indication_clusters(cluster_id);
|
||||
|
||||
-- Index for looking up by indication text
|
||||
CREATE INDEX IF NOT EXISTS idx_ref_drug_indication_clusters_indication ON ref_drug_indication_clusters(indication);
|
||||
"""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Fact Table Schemas
|
||||
# =============================================================================
|
||||
|
||||
FACT_INTERVENTIONS_SCHEMA = """
|
||||
-- Patient intervention records (fact table)
|
||||
-- Source: HCD activity data (CSV/Parquet files or Snowflake)
|
||||
-- This is the main fact table storing all patient intervention events
|
||||
CREATE TABLE IF NOT EXISTS fact_interventions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
|
||||
-- Patient identification
|
||||
upid TEXT NOT NULL, -- Unique Patient ID (Provider Code[:3] + PersonKey)
|
||||
provider_code TEXT NOT NULL, -- Original provider code (3-5 chars)
|
||||
person_key TEXT NOT NULL, -- Patient key from source system
|
||||
|
||||
-- Intervention details
|
||||
drug_name_raw TEXT, -- Original drug name from source
|
||||
drug_name_std TEXT NOT NULL, -- Standardized drug name (via ref_drug_names)
|
||||
intervention_date DATE NOT NULL, -- Date of intervention
|
||||
price_actual REAL NOT NULL DEFAULT 0, -- Cost of intervention in GBP
|
||||
|
||||
-- Organization and directory
|
||||
org_name TEXT, -- Organization name (cleaned, no commas)
|
||||
directory TEXT, -- Medical directory/specialty (may be "Undefined")
|
||||
|
||||
-- Source tracking
|
||||
source_file TEXT, -- Original file this record came from
|
||||
loaded_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
|
||||
-- Additional clinical fields (optional, used in directory fallback logic)
|
||||
treatment_function_code INTEGER,
|
||||
additional_detail_1 TEXT,
|
||||
additional_detail_2 TEXT,
|
||||
additional_detail_3 TEXT,
|
||||
additional_detail_4 TEXT,
|
||||
additional_detail_5 TEXT
|
||||
);
|
||||
|
||||
-- Primary indexes for common filter patterns used in generate_graph()
|
||||
-- UPID: Used for patient grouping, pathway analysis
|
||||
CREATE INDEX IF NOT EXISTS idx_fact_interventions_upid ON fact_interventions(upid);
|
||||
|
||||
-- Drug name (standardized): Used for drug filtering
|
||||
CREATE INDEX IF NOT EXISTS idx_fact_interventions_drug ON fact_interventions(drug_name_std);
|
||||
|
||||
-- Intervention date: Used for date range filtering (start_date, end_date, last_seen)
|
||||
CREATE INDEX IF NOT EXISTS idx_fact_interventions_date ON fact_interventions(intervention_date);
|
||||
|
||||
-- Directory: Used for directory/specialty filtering
|
||||
CREATE INDEX IF NOT EXISTS idx_fact_interventions_directory ON fact_interventions(directory);
|
||||
|
||||
-- Organization: Used for trust filtering (Provider Code maps to org_name)
|
||||
CREATE INDEX IF NOT EXISTS idx_fact_interventions_org ON fact_interventions(org_name);
|
||||
|
||||
-- Composite index for common filter combination (trust + drug + directory)
|
||||
CREATE INDEX IF NOT EXISTS idx_fact_interventions_composite
|
||||
ON fact_interventions(org_name, drug_name_std, directory);
|
||||
|
||||
-- Composite index for date-based patient analysis
|
||||
CREATE INDEX IF NOT EXISTS idx_fact_interventions_upid_date
|
||||
ON fact_interventions(upid, intervention_date);
|
||||
"""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Materialized View Schemas (Cached Aggregations)
|
||||
# =============================================================================
|
||||
|
||||
MV_PATIENT_TREATMENT_SUMMARY_SCHEMA = """
|
||||
-- Materialized view of patient treatment summaries
|
||||
-- Pre-computed aggregations per patient for faster pathway analysis
|
||||
-- Refreshed when fact_interventions data changes
|
||||
CREATE TABLE IF NOT EXISTS mv_patient_treatment_summary (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
|
||||
-- Patient identification
|
||||
upid TEXT NOT NULL UNIQUE, -- Unique Patient ID
|
||||
|
||||
-- Organization and directory (for filtering)
|
||||
org_name TEXT, -- Organization name (first org seen)
|
||||
directory TEXT, -- Primary directory (first directory assigned)
|
||||
|
||||
-- Date range
|
||||
first_seen_date DATE NOT NULL, -- First intervention date
|
||||
last_seen_date DATE NOT NULL, -- Last intervention date
|
||||
days_treated INTEGER NOT NULL DEFAULT 0, -- Duration: last_seen - first_seen
|
||||
|
||||
-- Cost aggregations
|
||||
total_cost REAL NOT NULL DEFAULT 0, -- Sum of all intervention costs
|
||||
avg_cost_per_intervention REAL, -- Average cost per intervention
|
||||
|
||||
-- Treatment summary
|
||||
intervention_count INTEGER NOT NULL DEFAULT 0, -- Total number of interventions
|
||||
unique_drug_count INTEGER NOT NULL DEFAULT 0, -- Number of distinct drugs
|
||||
|
||||
-- Drug sequence (pipe-separated standardized drug names in chronological order)
|
||||
-- Example: "ADALIMUMAB|ETANERCEPT|INFLIXIMAB"
|
||||
drug_sequence TEXT,
|
||||
|
||||
-- Drug frequency counts (JSON: {"ADALIMUMAB": 5, "ETANERCEPT": 3})
|
||||
-- Stores count of each drug for this patient
|
||||
drug_counts_json TEXT,
|
||||
|
||||
-- Drug cost totals (JSON: {"ADALIMUMAB": 15000.00, "ETANERCEPT": 8000.00})
|
||||
-- Stores total cost per drug for this patient
|
||||
drug_costs_json TEXT,
|
||||
|
||||
-- Per-drug date ranges (JSON: {"ADALIMUMAB": {"first": "2023-01-01", "last": "2023-06-15"}, ...})
|
||||
-- Stores first/last date for each drug
|
||||
drug_date_ranges_json TEXT,
|
||||
|
||||
-- Metadata
|
||||
computed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
source_row_count INTEGER -- Number of fact_interventions rows used
|
||||
);
|
||||
|
||||
-- Index for fast patient lookup
|
||||
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_upid ON mv_patient_treatment_summary(upid);
|
||||
|
||||
-- Indexes for common filter patterns
|
||||
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_org ON mv_patient_treatment_summary(org_name);
|
||||
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_directory ON mv_patient_treatment_summary(directory);
|
||||
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_first_seen ON mv_patient_treatment_summary(first_seen_date);
|
||||
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_last_seen ON mv_patient_treatment_summary(last_seen_date);
|
||||
|
||||
-- Composite index for date range filtering (common in generate_graph)
|
||||
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_date_range
|
||||
ON mv_patient_treatment_summary(first_seen_date, last_seen_date);
|
||||
|
||||
-- Composite index for org + directory + dates (full filter pattern)
|
||||
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_filter_composite
|
||||
ON mv_patient_treatment_summary(org_name, directory, first_seen_date, last_seen_date);
|
||||
|
||||
-- Index for drug sequence pattern matching
|
||||
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_drug_seq ON mv_patient_treatment_summary(drug_sequence);
|
||||
"""
|
||||
|
||||
MATERIALIZED_VIEWS_SCHEMA = f"""
|
||||
-- Materialized Views Schema
|
||||
-- Pre-computed aggregations for performance
|
||||
|
||||
{MV_PATIENT_TREATMENT_SUMMARY_SCHEMA}
|
||||
"""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# File Tracking Schemas (Incremental Updates)
|
||||
# =============================================================================
|
||||
|
||||
PROCESSED_FILES_SCHEMA = """
|
||||
-- Tracks processed data files for incremental updates
|
||||
-- Enables detecting changed files by comparing hashes
|
||||
-- Stores processing status and statistics
|
||||
CREATE TABLE IF NOT EXISTS processed_files (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
|
||||
-- File identification
|
||||
file_path TEXT NOT NULL, -- Full path to the file
|
||||
file_name TEXT NOT NULL, -- Just the filename (for display)
|
||||
file_hash TEXT NOT NULL, -- SHA256 hash of file contents
|
||||
|
||||
-- File metadata
|
||||
file_size_bytes INTEGER, -- Size of file in bytes
|
||||
file_modified_at TIMESTAMP, -- File's last modification timestamp
|
||||
|
||||
-- Processing results
|
||||
row_count INTEGER DEFAULT 0, -- Number of rows processed from this file
|
||||
status TEXT NOT NULL DEFAULT 'pending', -- pending, processing, success, error
|
||||
error_message TEXT, -- Error details if status='error'
|
||||
|
||||
-- Timestamps
|
||||
first_processed_at TIMESTAMP, -- When first processed
|
||||
last_processed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
processing_duration_seconds REAL, -- How long processing took
|
||||
|
||||
-- Uniqueness: only one record per file path
|
||||
-- Hash changes indicate file content changed (needs reprocessing)
|
||||
UNIQUE(file_path)
|
||||
);
|
||||
|
||||
-- Index for fast lookup by file path
|
||||
CREATE INDEX IF NOT EXISTS idx_processed_files_path ON processed_files(file_path);
|
||||
|
||||
-- Index for finding files by status (e.g., find all pending or errored files)
|
||||
CREATE INDEX IF NOT EXISTS idx_processed_files_status ON processed_files(status);
|
||||
|
||||
-- Index for finding files by hash (detect if same file appears at different paths)
|
||||
CREATE INDEX IF NOT EXISTS idx_processed_files_hash ON processed_files(file_hash);
|
||||
|
||||
-- Index for finding recently processed files
|
||||
CREATE INDEX IF NOT EXISTS idx_processed_files_last_processed ON processed_files(last_processed_at);
|
||||
"""
|
||||
|
||||
FILE_TRACKING_SCHEMA = f"""
|
||||
-- File Tracking Schema
|
||||
-- Supports incremental data loading
|
||||
|
||||
{PROCESSED_FILES_SCHEMA}
|
||||
"""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Combined Schemas
|
||||
# =============================================================================
|
||||
|
||||
REFERENCE_TABLES_SCHEMA = f"""
|
||||
-- Reference Tables Schema
|
||||
-- Contains lookup data migrated from CSV files
|
||||
|
||||
{REF_DRUG_NAMES_SCHEMA}
|
||||
|
||||
{REF_ORGANIZATIONS_SCHEMA}
|
||||
|
||||
{REF_DIRECTORIES_SCHEMA}
|
||||
|
||||
{REF_DRUG_DIRECTORY_MAP_SCHEMA}
|
||||
|
||||
{REF_DRUG_INDICATION_CLUSTERS_SCHEMA}
|
||||
"""
|
||||
|
||||
FACT_TABLES_SCHEMA = f"""
|
||||
-- Fact Tables Schema
|
||||
-- Contains patient intervention data
|
||||
|
||||
{FACT_INTERVENTIONS_SCHEMA}
|
||||
"""
|
||||
|
||||
ALL_TABLES_SCHEMA = f"""
|
||||
-- Complete Database Schema
|
||||
-- Reference tables + Fact tables + Materialized views + File tracking
|
||||
|
||||
{REFERENCE_TABLES_SCHEMA}
|
||||
|
||||
{FACT_TABLES_SCHEMA}
|
||||
|
||||
{MATERIALIZED_VIEWS_SCHEMA}
|
||||
|
||||
{FILE_TRACKING_SCHEMA}
|
||||
"""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Schema Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
def create_reference_tables(conn: sqlite3.Connection) -> None:
|
||||
"""
|
||||
Create all reference tables in the database.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
"""
|
||||
logger.info("Creating reference tables...")
|
||||
conn.executescript(REFERENCE_TABLES_SCHEMA)
|
||||
logger.info("Reference tables created successfully")
|
||||
|
||||
|
||||
def drop_reference_tables(conn: sqlite3.Connection) -> None:
|
||||
"""
|
||||
Drop all reference tables from the database.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Warning:
|
||||
This will delete all reference data. Use with caution.
|
||||
"""
|
||||
logger.warning("Dropping reference tables...")
|
||||
conn.executescript("""
|
||||
DROP TABLE IF EXISTS ref_drug_names;
|
||||
DROP TABLE IF EXISTS ref_organizations;
|
||||
DROP TABLE IF EXISTS ref_directories;
|
||||
DROP TABLE IF EXISTS ref_drug_directory_map;
|
||||
DROP TABLE IF EXISTS ref_drug_indication_clusters;
|
||||
""")
|
||||
logger.info("Reference tables dropped")
|
||||
|
||||
|
||||
def get_reference_table_counts(conn: sqlite3.Connection) -> dict[str, int]:
|
||||
"""
|
||||
Get row counts for all reference tables.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping table name to row count.
|
||||
"""
|
||||
tables = ["ref_drug_names", "ref_organizations", "ref_directories", "ref_drug_directory_map", "ref_drug_indication_clusters"]
|
||||
counts = {}
|
||||
|
||||
for table in tables:
|
||||
cursor = conn.execute(f"SELECT COUNT(*) FROM {table}")
|
||||
result = cursor.fetchone()
|
||||
counts[table] = result[0] if result else 0
|
||||
|
||||
return counts
|
||||
|
||||
|
||||
def verify_reference_tables_exist(conn: sqlite3.Connection) -> list[str]:
|
||||
"""
|
||||
Verify that all reference tables exist.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Returns:
|
||||
List of missing table names. Empty list means all tables exist.
|
||||
"""
|
||||
required_tables = ["ref_drug_names", "ref_organizations", "ref_directories", "ref_drug_directory_map", "ref_drug_indication_clusters"]
|
||||
missing = []
|
||||
|
||||
for table in required_tables:
|
||||
cursor = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name=?",
|
||||
(table,)
|
||||
)
|
||||
if cursor.fetchone() is None:
|
||||
missing.append(table)
|
||||
|
||||
return missing
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Fact Table Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
def create_fact_tables(conn: sqlite3.Connection) -> None:
|
||||
"""
|
||||
Create all fact tables in the database (including materialized views).
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
"""
|
||||
logger.info("Creating fact tables...")
|
||||
conn.executescript(FACT_TABLES_SCHEMA)
|
||||
conn.executescript(MATERIALIZED_VIEWS_SCHEMA)
|
||||
logger.info("Fact tables created successfully")
|
||||
|
||||
|
||||
def drop_fact_tables(conn: sqlite3.Connection) -> None:
|
||||
"""
|
||||
Drop all fact tables from the database.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Warning:
|
||||
This will delete all patient intervention data. Use with caution.
|
||||
"""
|
||||
logger.warning("Dropping fact tables...")
|
||||
conn.executescript("""
|
||||
DROP TABLE IF EXISTS fact_interventions;
|
||||
DROP TABLE IF EXISTS mv_patient_treatment_summary;
|
||||
""")
|
||||
logger.info("Fact tables dropped")
|
||||
|
||||
|
||||
def get_fact_table_counts(conn: sqlite3.Connection) -> dict[str, int]:
|
||||
"""
|
||||
Get row counts for all fact tables (including materialized views).
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping table name to row count.
|
||||
"""
|
||||
tables = ["fact_interventions", "mv_patient_treatment_summary"]
|
||||
counts = {}
|
||||
|
||||
for table in tables:
|
||||
cursor = conn.execute(f"SELECT COUNT(*) FROM {table}")
|
||||
result = cursor.fetchone()
|
||||
counts[table] = result[0] if result else 0
|
||||
|
||||
return counts
|
||||
|
||||
|
||||
def verify_fact_tables_exist(conn: sqlite3.Connection) -> list[str]:
|
||||
"""
|
||||
Verify that all fact tables exist (including materialized views).
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Returns:
|
||||
List of missing table names. Empty list means all tables exist.
|
||||
"""
|
||||
required_tables = ["fact_interventions", "mv_patient_treatment_summary"]
|
||||
missing = []
|
||||
|
||||
for table in required_tables:
|
||||
cursor = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name=?",
|
||||
(table,)
|
||||
)
|
||||
if cursor.fetchone() is None:
|
||||
missing.append(table)
|
||||
|
||||
return missing
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# File Tracking Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
def create_file_tracking_tables(conn: sqlite3.Connection) -> None:
|
||||
"""
|
||||
Create file tracking tables in the database.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
"""
|
||||
logger.info("Creating file tracking tables...")
|
||||
conn.executescript(FILE_TRACKING_SCHEMA)
|
||||
logger.info("File tracking tables created successfully")
|
||||
|
||||
|
||||
def drop_file_tracking_tables(conn: sqlite3.Connection) -> None:
|
||||
"""
|
||||
Drop file tracking tables from the database.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Warning:
|
||||
This will delete all file tracking history.
|
||||
"""
|
||||
logger.warning("Dropping file tracking tables...")
|
||||
conn.executescript("""
|
||||
DROP TABLE IF EXISTS processed_files;
|
||||
""")
|
||||
logger.info("File tracking tables dropped")
|
||||
|
||||
|
||||
def get_file_tracking_counts(conn: sqlite3.Connection) -> dict[str, int]:
|
||||
"""
|
||||
Get row counts for file tracking tables.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping table name to row count.
|
||||
"""
|
||||
tables = ["processed_files"]
|
||||
counts = {}
|
||||
|
||||
for table in tables:
|
||||
cursor = conn.execute(f"SELECT COUNT(*) FROM {table}")
|
||||
result = cursor.fetchone()
|
||||
counts[table] = result[0] if result else 0
|
||||
|
||||
return counts
|
||||
|
||||
|
||||
def verify_file_tracking_tables_exist(conn: sqlite3.Connection) -> list[str]:
|
||||
"""
|
||||
Verify that file tracking tables exist.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Returns:
|
||||
List of missing table names. Empty list means all tables exist.
|
||||
"""
|
||||
required_tables = ["processed_files"]
|
||||
missing = []
|
||||
|
||||
for table in required_tables:
|
||||
cursor = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name=?",
|
||||
(table,)
|
||||
)
|
||||
if cursor.fetchone() is None:
|
||||
missing.append(table)
|
||||
|
||||
return missing
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Combined Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
def create_all_tables(conn: sqlite3.Connection) -> None:
|
||||
"""
|
||||
Create all tables (reference + fact) in the database.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
"""
|
||||
logger.info("Creating all database tables...")
|
||||
conn.executescript(ALL_TABLES_SCHEMA)
|
||||
logger.info("All tables created successfully")
|
||||
|
||||
|
||||
def drop_all_tables(conn: sqlite3.Connection) -> None:
|
||||
"""
|
||||
Drop all tables from the database.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Warning:
|
||||
This will delete all data. Use with extreme caution.
|
||||
"""
|
||||
logger.warning("Dropping all tables...")
|
||||
drop_file_tracking_tables(conn)
|
||||
drop_fact_tables(conn)
|
||||
drop_reference_tables(conn)
|
||||
logger.info("All tables dropped")
|
||||
|
||||
|
||||
def get_all_table_counts(conn: sqlite3.Connection) -> dict[str, int]:
|
||||
"""
|
||||
Get row counts for all tables.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping table name to row count.
|
||||
"""
|
||||
counts = {}
|
||||
counts.update(get_reference_table_counts(conn))
|
||||
counts.update(get_fact_table_counts(conn))
|
||||
counts.update(get_file_tracking_counts(conn))
|
||||
return counts
|
||||
|
||||
|
||||
def verify_all_tables_exist(conn: sqlite3.Connection) -> list[str]:
|
||||
"""
|
||||
Verify that all tables exist.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Returns:
|
||||
List of missing table names. Empty list means all tables exist.
|
||||
"""
|
||||
missing = []
|
||||
missing.extend(verify_reference_tables_exist(conn))
|
||||
missing.extend(verify_fact_tables_exist(conn))
|
||||
missing.extend(verify_file_tracking_tables_exist(conn))
|
||||
return missing
|
||||
@@ -0,0 +1,797 @@
|
||||
"""
|
||||
Snowflake connector module for NHS Patient Pathway Analysis.
|
||||
|
||||
Provides connection handling with SSO browser authentication for NHS environments.
|
||||
Uses the externalbrowser authenticator which opens a browser window for NHS identity
|
||||
management authentication.
|
||||
|
||||
Usage:
|
||||
from data_processing.snowflake_connector import SnowflakeConnector, get_connector
|
||||
|
||||
# Using context manager (recommended)
|
||||
with get_connector() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM table LIMIT 10")
|
||||
results = cursor.fetchall()
|
||||
|
||||
# Manual connection management
|
||||
connector = SnowflakeConnector()
|
||||
try:
|
||||
conn = connector.connect()
|
||||
cursor = conn.cursor()
|
||||
# ... use cursor ...
|
||||
finally:
|
||||
connector.close()
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator, Optional, TYPE_CHECKING
|
||||
import time
|
||||
|
||||
# Snowflake connector is an optional dependency
|
||||
SNOWFLAKE_AVAILABLE = False
|
||||
try:
|
||||
import snowflake.connector
|
||||
from snowflake.connector import SnowflakeConnection
|
||||
from snowflake.connector.cursor import SnowflakeCursor
|
||||
SNOWFLAKE_AVAILABLE = True
|
||||
except ImportError:
|
||||
snowflake = None # type: ignore[assignment]
|
||||
|
||||
# Type hints for when snowflake is not available
|
||||
if TYPE_CHECKING:
|
||||
from snowflake.connector import SnowflakeConnection
|
||||
from snowflake.connector.cursor import SnowflakeCursor
|
||||
|
||||
from config import get_snowflake_config, SnowflakeConfig
|
||||
from core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SnowflakeConnectionError(Exception):
|
||||
"""Raised when Snowflake connection fails."""
|
||||
pass
|
||||
|
||||
|
||||
class SnowflakeNotConfiguredError(Exception):
|
||||
"""Raised when Snowflake is not configured (no account)."""
|
||||
pass
|
||||
|
||||
|
||||
class SnowflakeNotAvailableError(Exception):
|
||||
"""Raised when snowflake-connector-python is not installed."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionInfo:
|
||||
"""Information about the current connection state."""
|
||||
connected: bool = False
|
||||
account: str = ""
|
||||
warehouse: str = ""
|
||||
database: str = ""
|
||||
schema: str = ""
|
||||
user: str = ""
|
||||
role: str = ""
|
||||
connected_at: Optional[datetime] = None
|
||||
last_query_at: Optional[datetime] = None
|
||||
query_count: int = 0
|
||||
|
||||
|
||||
class SnowflakeConnector:
|
||||
"""
|
||||
Manages Snowflake connections with SSO browser authentication.
|
||||
|
||||
This class provides connection management for NHS Snowflake access using
|
||||
the externalbrowser authenticator which triggers NHS SSO login via browser.
|
||||
|
||||
Attributes:
|
||||
config: SnowflakeConfig with connection settings
|
||||
connection_info: ConnectionInfo tracking current state
|
||||
|
||||
Example:
|
||||
connector = SnowflakeConnector()
|
||||
with connector.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT CURRENT_USER()")
|
||||
print(cursor.fetchone()[0])
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[SnowflakeConfig] = None):
|
||||
"""
|
||||
Initialize the connector with configuration.
|
||||
|
||||
Args:
|
||||
config: Optional SnowflakeConfig. If not provided, loads from
|
||||
config/snowflake.toml using get_snowflake_config().
|
||||
"""
|
||||
self._config = config or get_snowflake_config()
|
||||
self._connection: Optional[SnowflakeConnection] = None
|
||||
self._connection_info = ConnectionInfo()
|
||||
|
||||
@property
|
||||
def config(self) -> SnowflakeConfig:
|
||||
"""Return the Snowflake configuration."""
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def connection_info(self) -> ConnectionInfo:
|
||||
"""Return information about the current connection state."""
|
||||
return self._connection_info
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Return True if currently connected to Snowflake."""
|
||||
return self._connection is not None and not self._connection.is_closed()
|
||||
|
||||
def _check_availability(self) -> None:
|
||||
"""Check that snowflake-connector-python is installed."""
|
||||
if not SNOWFLAKE_AVAILABLE:
|
||||
raise SnowflakeNotAvailableError(
|
||||
"snowflake-connector-python is not installed. "
|
||||
"Install it with: pip install snowflake-connector-python"
|
||||
)
|
||||
|
||||
def _check_configured(self) -> None:
|
||||
"""Check that Snowflake is configured."""
|
||||
if not self._config.is_configured:
|
||||
raise SnowflakeNotConfiguredError(
|
||||
"Snowflake account is not configured. "
|
||||
"Edit config/snowflake.toml and set connection.account"
|
||||
)
|
||||
|
||||
def connect(self) -> SnowflakeConnection:
|
||||
"""
|
||||
Establish a connection to Snowflake.
|
||||
|
||||
Uses the externalbrowser authenticator which opens a browser window
|
||||
for NHS SSO authentication. The browser popup is expected and normal.
|
||||
|
||||
Returns:
|
||||
Active SnowflakeConnection
|
||||
|
||||
Raises:
|
||||
SnowflakeNotAvailableError: If snowflake-connector-python not installed
|
||||
SnowflakeNotConfiguredError: If account is not configured
|
||||
SnowflakeConnectionError: If connection fails
|
||||
"""
|
||||
self._check_availability()
|
||||
self._check_configured()
|
||||
|
||||
# Close existing connection if any
|
||||
if self._connection is not None:
|
||||
self.close()
|
||||
|
||||
conn_cfg = self._config.connection
|
||||
timeout_cfg = self._config.timeouts
|
||||
|
||||
logger.info(f"Connecting to Snowflake account: {conn_cfg.account}")
|
||||
logger.info(f"Using warehouse: {conn_cfg.warehouse}, database: {conn_cfg.database}")
|
||||
logger.info(f"Authenticator: {conn_cfg.authenticator}")
|
||||
if conn_cfg.authenticator == "externalbrowser":
|
||||
logger.info("Browser window will open for NHS SSO authentication")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Build connection parameters
|
||||
connect_params = {
|
||||
"account": conn_cfg.account,
|
||||
"warehouse": conn_cfg.warehouse,
|
||||
"database": conn_cfg.database,
|
||||
"schema": conn_cfg.schema,
|
||||
"authenticator": conn_cfg.authenticator,
|
||||
"login_timeout": timeout_cfg.login_timeout,
|
||||
"network_timeout": timeout_cfg.connection_timeout,
|
||||
}
|
||||
|
||||
# Optional parameters (only add if set)
|
||||
if conn_cfg.user:
|
||||
connect_params["user"] = conn_cfg.user
|
||||
if conn_cfg.role:
|
||||
connect_params["role"] = conn_cfg.role
|
||||
|
||||
self._connection = snowflake.connector.connect(**connect_params)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(f"Connected to Snowflake successfully in {elapsed:.1f}s")
|
||||
|
||||
# Update connection info
|
||||
self._connection_info = ConnectionInfo(
|
||||
connected=True,
|
||||
account=conn_cfg.account,
|
||||
warehouse=conn_cfg.warehouse,
|
||||
database=conn_cfg.database,
|
||||
schema=conn_cfg.schema,
|
||||
user=self._get_current_user(),
|
||||
role=self._get_current_role(),
|
||||
connected_at=datetime.now(),
|
||||
query_count=0,
|
||||
)
|
||||
|
||||
return self._connection
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.time() - start_time
|
||||
logger.error(f"Failed to connect to Snowflake after {elapsed:.1f}s: {e}")
|
||||
self._connection_info = ConnectionInfo(connected=False)
|
||||
raise SnowflakeConnectionError(f"Failed to connect to Snowflake: {e}") from e
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the Snowflake connection if open."""
|
||||
if self._connection is not None:
|
||||
try:
|
||||
self._connection.close()
|
||||
logger.info("Snowflake connection closed")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing Snowflake connection: {e}")
|
||||
finally:
|
||||
self._connection = None
|
||||
self._connection_info = ConnectionInfo(connected=False)
|
||||
|
||||
def _get_current_user(self) -> str:
|
||||
"""Get the current authenticated user."""
|
||||
if self._connection is None:
|
||||
return ""
|
||||
try:
|
||||
cursor = self._connection.cursor()
|
||||
cursor.execute("SELECT CURRENT_USER()")
|
||||
result = cursor.fetchone()
|
||||
return result[0] if result else ""
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
def _get_current_role(self) -> str:
|
||||
"""Get the current active role."""
|
||||
if self._connection is None:
|
||||
return ""
|
||||
try:
|
||||
cursor = self._connection.cursor()
|
||||
cursor.execute("SELECT CURRENT_ROLE()")
|
||||
result = cursor.fetchone()
|
||||
return result[0] if result else ""
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self) -> Generator[SnowflakeConnection, None, None]:
|
||||
"""
|
||||
Context manager for connection handling.
|
||||
|
||||
Creates a new connection if not already connected, yields the connection,
|
||||
and ensures proper cleanup on exit.
|
||||
|
||||
Yields:
|
||||
Active SnowflakeConnection
|
||||
|
||||
Example:
|
||||
connector = SnowflakeConnector()
|
||||
with connector.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.connect()
|
||||
|
||||
assert self._connection is not None, "Connection should be established"
|
||||
try:
|
||||
yield self._connection
|
||||
finally:
|
||||
# Keep connection open for reuse
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
def get_cursor(
|
||||
self,
|
||||
dict_cursor: bool = False
|
||||
) -> Generator[SnowflakeCursor, None, None]:
|
||||
"""
|
||||
Context manager that provides a cursor.
|
||||
|
||||
Args:
|
||||
dict_cursor: If True, returns cursor that yields dict-like rows
|
||||
|
||||
Yields:
|
||||
SnowflakeCursor for executing queries
|
||||
|
||||
Example:
|
||||
connector = SnowflakeConnector()
|
||||
with connector.get_cursor() as cursor:
|
||||
cursor.execute("SELECT * FROM table LIMIT 10")
|
||||
for row in cursor:
|
||||
print(row)
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.connect()
|
||||
|
||||
assert self._connection is not None, "Connection should be established"
|
||||
cursor: Any = None
|
||||
try:
|
||||
if dict_cursor:
|
||||
cursor = self._connection.cursor(snowflake.connector.DictCursor) # type: ignore[union-attr]
|
||||
else:
|
||||
cursor = self._connection.cursor()
|
||||
yield cursor # type: ignore[misc]
|
||||
self._connection_info.last_query_at = datetime.now()
|
||||
self._connection_info.query_count += 1
|
||||
finally:
|
||||
if cursor is not None:
|
||||
cursor.close()
|
||||
|
||||
def execute(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[tuple] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> list[tuple]:
|
||||
"""
|
||||
Execute a query and return all results.
|
||||
|
||||
Args:
|
||||
query: SQL query to execute
|
||||
params: Optional query parameters for parameterized queries
|
||||
timeout: Optional query timeout in seconds (overrides config)
|
||||
|
||||
Returns:
|
||||
List of result rows as tuples
|
||||
|
||||
Raises:
|
||||
SnowflakeConnectionError: If not connected
|
||||
Various snowflake errors for query issues
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.connect()
|
||||
|
||||
effective_timeout = timeout or self._config.timeouts.query_timeout
|
||||
|
||||
with self.get_cursor() as cursor:
|
||||
logger.info(f"Executing query (timeout={effective_timeout}s)")
|
||||
logger.debug(f"Query: {query[:200]}...")
|
||||
|
||||
if effective_timeout > 0:
|
||||
cursor.execute(f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {effective_timeout}")
|
||||
|
||||
start_time = time.time()
|
||||
cursor.execute(query, params)
|
||||
results = cursor.fetchall()
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
logger.info(f"Query returned {len(results)} rows in {elapsed:.2f}s")
|
||||
return results
|
||||
|
||||
def execute_dict(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[tuple] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Execute a query and return results as list of dictionaries.
|
||||
|
||||
Args:
|
||||
query: SQL query to execute
|
||||
params: Optional query parameters
|
||||
timeout: Optional query timeout in seconds
|
||||
|
||||
Returns:
|
||||
List of result rows as dictionaries
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.connect()
|
||||
|
||||
effective_timeout = timeout or self._config.timeouts.query_timeout
|
||||
|
||||
with self.get_cursor(dict_cursor=True) as cursor:
|
||||
logger.info(f"Executing query (timeout={effective_timeout}s)")
|
||||
logger.debug(f"Query: {query[:200]}...")
|
||||
|
||||
if effective_timeout > 0:
|
||||
cursor.execute(f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {effective_timeout}")
|
||||
|
||||
start_time = time.time()
|
||||
cursor.execute(query, params)
|
||||
results = cursor.fetchall()
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
logger.info(f"Query returned {len(results)} rows in {elapsed:.2f}s")
|
||||
return results # type: ignore[return-value]
|
||||
|
||||
def execute_chunked(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[tuple] = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
timeout: Optional[int] = None,
|
||||
max_rows: Optional[int] = None,
|
||||
) -> Generator[list[tuple], None, None]:
|
||||
"""
|
||||
Execute a query and yield results in chunks for memory efficiency.
|
||||
|
||||
This method is useful for large result sets that would exceed memory
|
||||
if loaded all at once. Results are yielded as chunks of rows.
|
||||
|
||||
Args:
|
||||
query: SQL query to execute
|
||||
params: Optional query parameters for parameterized queries
|
||||
chunk_size: Number of rows per chunk (default from config)
|
||||
timeout: Optional query timeout in seconds (overrides config)
|
||||
max_rows: Maximum total rows to return (default from config, 0 for no limit)
|
||||
|
||||
Yields:
|
||||
List of result rows as tuples for each chunk
|
||||
|
||||
Example:
|
||||
for chunk in connector.execute_chunked("SELECT * FROM large_table"):
|
||||
process_chunk(chunk)
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.connect()
|
||||
|
||||
effective_timeout = timeout or self._config.timeouts.query_timeout
|
||||
effective_chunk_size = chunk_size or self._config.query.chunk_size
|
||||
effective_max_rows = max_rows if max_rows is not None else self._config.query.max_rows
|
||||
|
||||
with self.get_cursor() as cursor:
|
||||
logger.info(f"Executing chunked query (chunk_size={effective_chunk_size}, timeout={effective_timeout}s)")
|
||||
logger.debug(f"Query: {query[:200]}...")
|
||||
|
||||
if effective_timeout > 0:
|
||||
cursor.execute(f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {effective_timeout}")
|
||||
|
||||
start_time = time.time()
|
||||
cursor.execute(query, params)
|
||||
|
||||
total_rows = 0
|
||||
chunk_num = 0
|
||||
|
||||
while True:
|
||||
# Determine how many rows to fetch this chunk
|
||||
if effective_max_rows > 0:
|
||||
remaining = effective_max_rows - total_rows
|
||||
if remaining <= 0:
|
||||
break
|
||||
fetch_size = min(effective_chunk_size, remaining)
|
||||
else:
|
||||
fetch_size = effective_chunk_size
|
||||
|
||||
chunk = cursor.fetchmany(fetch_size)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
chunk_num += 1
|
||||
total_rows += len(chunk)
|
||||
logger.debug(f"Chunk {chunk_num}: {len(chunk)} rows (total: {total_rows})")
|
||||
yield chunk
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(f"Chunked query returned {total_rows} rows in {chunk_num} chunks ({elapsed:.2f}s)")
|
||||
|
||||
def execute_chunked_dict(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[tuple] = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
timeout: Optional[int] = None,
|
||||
max_rows: Optional[int] = None,
|
||||
) -> Generator[list[dict], None, None]:
|
||||
"""
|
||||
Execute a query and yield dict results in chunks for memory efficiency.
|
||||
|
||||
Same as execute_chunked but returns rows as dictionaries.
|
||||
|
||||
Args:
|
||||
query: SQL query to execute
|
||||
params: Optional query parameters
|
||||
chunk_size: Number of rows per chunk (default from config)
|
||||
timeout: Optional query timeout in seconds
|
||||
max_rows: Maximum total rows to return (default from config, 0 for no limit)
|
||||
|
||||
Yields:
|
||||
List of result rows as dictionaries for each chunk
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.connect()
|
||||
|
||||
effective_timeout = timeout or self._config.timeouts.query_timeout
|
||||
effective_chunk_size = chunk_size or self._config.query.chunk_size
|
||||
effective_max_rows = max_rows if max_rows is not None else self._config.query.max_rows
|
||||
|
||||
with self.get_cursor(dict_cursor=True) as cursor:
|
||||
logger.info(f"Executing chunked dict query (chunk_size={effective_chunk_size}, timeout={effective_timeout}s)")
|
||||
logger.debug(f"Query: {query[:200]}...")
|
||||
|
||||
if effective_timeout > 0:
|
||||
cursor.execute(f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {effective_timeout}")
|
||||
|
||||
start_time = time.time()
|
||||
cursor.execute(query, params)
|
||||
|
||||
total_rows = 0
|
||||
chunk_num = 0
|
||||
|
||||
while True:
|
||||
# Determine how many rows to fetch this chunk
|
||||
if effective_max_rows > 0:
|
||||
remaining = effective_max_rows - total_rows
|
||||
if remaining <= 0:
|
||||
break
|
||||
fetch_size = min(effective_chunk_size, remaining)
|
||||
else:
|
||||
fetch_size = effective_chunk_size
|
||||
|
||||
chunk = cursor.fetchmany(fetch_size)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
chunk_num += 1
|
||||
total_rows += len(chunk)
|
||||
logger.debug(f"Chunk {chunk_num}: {len(chunk)} rows (total: {total_rows})")
|
||||
yield chunk # type: ignore[misc]
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(f"Chunked dict query returned {total_rows} rows in {chunk_num} chunks ({elapsed:.2f}s)")
|
||||
|
||||
def execute_with_row_limit(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[tuple] = None,
|
||||
max_rows: Optional[int] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> tuple[list[dict], bool]:
|
||||
"""
|
||||
Execute a query with a row limit and indicate if more rows were available.
|
||||
|
||||
This is useful for pagination or previewing large result sets.
|
||||
|
||||
Args:
|
||||
query: SQL query to execute
|
||||
params: Optional query parameters
|
||||
max_rows: Maximum rows to return (default from config)
|
||||
timeout: Optional query timeout in seconds
|
||||
|
||||
Returns:
|
||||
Tuple of (results list, has_more bool)
|
||||
- results: List of result rows as dictionaries (up to max_rows)
|
||||
- has_more: True if there were more rows than max_rows
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.connect()
|
||||
|
||||
effective_timeout = timeout or self._config.timeouts.query_timeout
|
||||
effective_max_rows = max_rows if max_rows is not None else self._config.query.max_rows
|
||||
|
||||
with self.get_cursor(dict_cursor=True) as cursor:
|
||||
logger.info(f"Executing query with limit (max_rows={effective_max_rows}, timeout={effective_timeout}s)")
|
||||
logger.debug(f"Query: {query[:200]}...")
|
||||
|
||||
if effective_timeout > 0:
|
||||
cursor.execute(f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {effective_timeout}")
|
||||
|
||||
start_time = time.time()
|
||||
cursor.execute(query, params)
|
||||
|
||||
# Fetch one more than max to detect if there are more rows
|
||||
results = cursor.fetchmany(effective_max_rows + 1)
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
has_more = len(results) > effective_max_rows
|
||||
if has_more:
|
||||
results = results[:effective_max_rows]
|
||||
|
||||
logger.info(f"Query returned {len(results)} rows (has_more={has_more}) in {elapsed:.2f}s")
|
||||
return results, has_more # type: ignore[return-value]
|
||||
|
||||
def fetch_activity_data(
|
||||
self,
|
||||
start_date: Optional[date] = None,
|
||||
end_date: Optional[date] = None,
|
||||
provider_codes: Optional[list[str]] = None,
|
||||
max_rows: Optional[int] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Fetch high-cost drug activity data from Snowflake.
|
||||
|
||||
Queries the CDM.Acute__Conmon__PatientLevelDrugs table and returns
|
||||
data in a format compatible with the existing analysis pipeline.
|
||||
|
||||
Args:
|
||||
start_date: Optional start date for filtering (inclusive)
|
||||
end_date: Optional end date for filtering (inclusive)
|
||||
provider_codes: Optional list of provider codes to filter by
|
||||
max_rows: Maximum rows to return (default from config)
|
||||
timeout: Query timeout in seconds (default from config)
|
||||
|
||||
Returns:
|
||||
List of dictionaries with keys matching expected DataFrame columns:
|
||||
- PseudoNHSNoLinked: Pseudonymised NHS number (for UPID creation)
|
||||
- Provider Code: NHS provider code
|
||||
- PersonKey: Local patient identifier
|
||||
- Drug Name: Raw drug name
|
||||
- Intervention Date: Date of intervention
|
||||
- Price Actual: Cost of intervention
|
||||
- OrganisationName: Provider organisation name
|
||||
- Treatment Function Code: NHS treatment function code
|
||||
- Additional Detail 1-5: Additional details for directory identification
|
||||
|
||||
Raises:
|
||||
SnowflakeConnectionError: If not connected or query fails
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.connect()
|
||||
|
||||
# Build the query
|
||||
table_name = 'DATA_HUB.CDM."Acute__Conmon__PatientLevelDrugs"'
|
||||
|
||||
query = f'''
|
||||
SELECT
|
||||
"PseudoNHSNoLinked",
|
||||
"ProviderCode" AS "Provider Code",
|
||||
"LocalPatientID" AS "PersonKey",
|
||||
"DrugName" AS "Drug Name",
|
||||
"InterventionDate" AS "Intervention Date",
|
||||
"PriceActual" AS "Price Actual",
|
||||
"ProviderName" AS "OrganisationName",
|
||||
"TreatmentFunctionCode" AS "Treatment Function Code",
|
||||
"TreatmentFunctionDesc" AS "Treatment Function Desc",
|
||||
"AdditionalDetail1" AS "Additional Detail 1",
|
||||
"AdditionalDescription1" AS "Additional Description 1",
|
||||
"AdditionalDetail2" AS "Additional Detail 2",
|
||||
"AdditionalDescription2" AS "Additional Description 2",
|
||||
"AdditionalDetail3" AS "Additional Detail 3",
|
||||
"AdditionalDescription3" AS "Additional Description 3",
|
||||
"AdditionalDetail4" AS "Additional Detail 4",
|
||||
"AdditionalDescription4" AS "Additional Description 4",
|
||||
"AdditionalDetail5" AS "Additional Detail 5",
|
||||
"AdditionalDescription5" AS "Additional Description 5"
|
||||
FROM {table_name}
|
||||
WHERE 1=1
|
||||
'''
|
||||
|
||||
params = []
|
||||
|
||||
# Add date filters
|
||||
if start_date:
|
||||
query += ' AND "InterventionDate" >= %s'
|
||||
params.append(start_date.isoformat())
|
||||
if end_date:
|
||||
query += ' AND "InterventionDate" <= %s'
|
||||
params.append(end_date.isoformat())
|
||||
|
||||
# Add provider filter
|
||||
if provider_codes:
|
||||
placeholders = ", ".join(["%s"] * len(provider_codes))
|
||||
query += f' AND "ProviderCode" IN ({placeholders})'
|
||||
params.extend(provider_codes)
|
||||
|
||||
# Add ordering for consistent results
|
||||
query += ' ORDER BY "InterventionDate", "ProviderCode", "PseudoNHSNoLinked"'
|
||||
|
||||
logger.info(f"Fetching activity data from Snowflake")
|
||||
if start_date:
|
||||
logger.info(f" Date range: {start_date} to {end_date or 'now'}")
|
||||
if provider_codes:
|
||||
logger.info(f" Providers: {provider_codes}")
|
||||
|
||||
effective_max_rows = max_rows if max_rows is not None else self._config.query.max_rows
|
||||
effective_timeout = timeout or self._config.timeouts.query_timeout
|
||||
|
||||
# Execute with chunked results for large datasets
|
||||
all_results = []
|
||||
total_rows = 0
|
||||
|
||||
for chunk in self.execute_chunked_dict(
|
||||
query,
|
||||
params=tuple(params) if params else None,
|
||||
timeout=effective_timeout,
|
||||
max_rows=effective_max_rows,
|
||||
):
|
||||
all_results.extend(chunk)
|
||||
total_rows += len(chunk)
|
||||
logger.debug(f"Fetched {total_rows} rows so far...")
|
||||
|
||||
logger.info(f"Fetched {len(all_results)} activity records from Snowflake")
|
||||
return all_results
|
||||
|
||||
def test_connection(self) -> tuple[bool, str]:
|
||||
"""
|
||||
Test the Snowflake connection.
|
||||
|
||||
Returns:
|
||||
Tuple of (success: bool, message: str)
|
||||
"""
|
||||
try:
|
||||
self._check_availability()
|
||||
except SnowflakeNotAvailableError as e:
|
||||
return False, str(e)
|
||||
|
||||
try:
|
||||
self._check_configured()
|
||||
except SnowflakeNotConfiguredError as e:
|
||||
return False, str(e)
|
||||
|
||||
try:
|
||||
self.connect()
|
||||
user = self._get_current_user()
|
||||
role = self._get_current_role()
|
||||
return True, f"Connected as {user} with role {role}"
|
||||
except Exception as e:
|
||||
return False, f"Connection failed: {e}"
|
||||
|
||||
def __enter__(self) -> "SnowflakeConnector":
|
||||
"""Context manager entry."""
|
||||
self.connect()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||
"""Context manager exit."""
|
||||
self.close()
|
||||
|
||||
|
||||
# Module-level singleton for convenience
|
||||
_default_connector: Optional[SnowflakeConnector] = None
|
||||
|
||||
|
||||
def get_connector(config: Optional[SnowflakeConfig] = None) -> SnowflakeConnector:
|
||||
"""
|
||||
Get a Snowflake connector (creates singleton on first call).
|
||||
|
||||
Args:
|
||||
config: Optional configuration. If provided, creates new connector
|
||||
with this config. If None, uses/creates default connector.
|
||||
|
||||
Returns:
|
||||
SnowflakeConnector instance
|
||||
"""
|
||||
global _default_connector
|
||||
|
||||
if config is not None:
|
||||
# Custom config requested, create new connector
|
||||
return SnowflakeConnector(config)
|
||||
|
||||
if _default_connector is None:
|
||||
_default_connector = SnowflakeConnector()
|
||||
|
||||
return _default_connector
|
||||
|
||||
|
||||
def reset_connector() -> None:
|
||||
"""Reset the default connector (closes connection and clears singleton)."""
|
||||
global _default_connector
|
||||
|
||||
if _default_connector is not None:
|
||||
_default_connector.close()
|
||||
_default_connector = None
|
||||
|
||||
|
||||
def is_snowflake_available() -> bool:
|
||||
"""Return True if snowflake-connector-python is installed."""
|
||||
return SNOWFLAKE_AVAILABLE
|
||||
|
||||
|
||||
def is_snowflake_configured() -> bool:
|
||||
"""Return True if Snowflake account is configured."""
|
||||
try:
|
||||
config = get_snowflake_config()
|
||||
return config.is_configured
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
# Export public API
|
||||
__all__ = [
|
||||
"SnowflakeConnector",
|
||||
"SnowflakeConnectionError",
|
||||
"SnowflakeNotConfiguredError",
|
||||
"SnowflakeNotAvailableError",
|
||||
"ConnectionInfo",
|
||||
"get_connector",
|
||||
"reset_connector",
|
||||
"is_snowflake_available",
|
||||
"is_snowflake_configured",
|
||||
"SNOWFLAKE_AVAILABLE",
|
||||
]
|
||||
Reference in New Issue
Block a user