Initial commit before Ralph loop

This commit is contained in:
Andrew Charlwood
2026-02-04 13:04:29 +00:00
commit fdd33a67af
89 changed files with 20660 additions and 0 deletions
+273
View File
@@ -0,0 +1,273 @@
"""
Data processing module for NHS High-Cost Drug Patient Pathway Analysis Tool.
Contains SQLite database management, data loaders, and Snowflake integration.
Handles the migration from CSV-based storage to SQLite for improved performance.
Submodules:
database: SQLite connection management and schema definitions
loader: Data loading abstractions (CSV, SQLite, Snowflake)
snowflake_connector: Snowflake integration with SSO authentication
"""
from data_processing.database import (
DatabaseConfig,
DatabaseManager,
default_db_config,
default_db_manager,
)
from data_processing.schema import (
# Reference table schemas
REF_DRUG_NAMES_SCHEMA,
REF_ORGANIZATIONS_SCHEMA,
REF_DIRECTORIES_SCHEMA,
REF_DRUG_DIRECTORY_MAP_SCHEMA,
REF_DRUG_INDICATION_CLUSTERS_SCHEMA,
REFERENCE_TABLES_SCHEMA,
# Fact table schemas
FACT_INTERVENTIONS_SCHEMA,
FACT_TABLES_SCHEMA,
# Materialized view schemas
MV_PATIENT_TREATMENT_SUMMARY_SCHEMA,
MATERIALIZED_VIEWS_SCHEMA,
# File tracking schemas
PROCESSED_FILES_SCHEMA,
FILE_TRACKING_SCHEMA,
# Combined schema
ALL_TABLES_SCHEMA,
# Reference table functions
create_reference_tables,
drop_reference_tables,
get_reference_table_counts,
verify_reference_tables_exist,
# Fact table functions
create_fact_tables,
drop_fact_tables,
get_fact_table_counts,
verify_fact_tables_exist,
# File tracking functions
create_file_tracking_tables,
drop_file_tracking_tables,
get_file_tracking_counts,
verify_file_tracking_tables_exist,
# Combined functions
create_all_tables,
drop_all_tables,
get_all_table_counts,
verify_all_tables_exist,
)
# Reference data migration functions
from data_processing.reference_data import (
MigrationResult,
migrate_drug_names,
get_drug_name_counts,
verify_drug_names_migration,
migrate_organizations,
get_organization_counts,
verify_organizations_migration,
migrate_directories,
get_directory_counts,
verify_directories_migration,
migrate_drug_directory_map,
get_drug_directory_map_counts,
verify_drug_directory_map_migration,
migrate_drug_indication_clusters,
get_drug_indication_cluster_counts,
verify_drug_indication_clusters_migration,
)
# Data loader abstractions
from data_processing.loader import (
DataLoader,
FileDataLoader,
SQLiteDataLoader,
LoadResult,
get_loader,
REQUIRED_COLUMNS,
OPTIONAL_COLUMNS,
)
# Patient data migration functions
from data_processing.patient_data import (
PatientDataLoadResult,
load_patient_data,
get_patient_data_stats,
list_processed_files,
calculate_file_hash,
# Materialized view functions
MVRefreshResult,
refresh_patient_treatment_summary,
get_patient_summary_stats,
verify_mv_consistency,
)
# Snowflake connector
from data_processing.snowflake_connector import (
SnowflakeConnector,
SnowflakeConnectionError,
SnowflakeNotConfiguredError,
SnowflakeNotAvailableError,
ConnectionInfo,
get_connector,
reset_connector,
is_snowflake_available,
is_snowflake_configured,
SNOWFLAKE_AVAILABLE,
)
# Query result caching
from data_processing.cache import (
QueryCache,
CacheEntry,
CacheStats,
get_cache,
reset_cache,
is_cache_enabled,
)
# Data source management with fallback chain
from data_processing.data_source import (
DataSourceType,
DataSourceResult,
SourceStatus,
DataSourceManager,
get_data_source_manager,
get_data,
reset_data_source_manager,
)
# Diagnosis lookup (GP diagnosis validation)
from data_processing.diagnosis_lookup import (
ClusterSnomedCodes,
IndicationValidationResult,
DrugIndicationMatchRate,
get_drug_clusters,
get_drug_cluster_ids,
get_cluster_snomed_codes,
patient_has_indication,
validate_indication,
get_indication_match_rate,
batch_validate_indications,
get_available_clusters,
)
__all__ = [
# Database management
"DatabaseConfig",
"DatabaseManager",
"default_db_config",
"default_db_manager",
# Reference table schemas
"REF_DRUG_NAMES_SCHEMA",
"REF_ORGANIZATIONS_SCHEMA",
"REF_DIRECTORIES_SCHEMA",
"REF_DRUG_DIRECTORY_MAP_SCHEMA",
"REF_DRUG_INDICATION_CLUSTERS_SCHEMA",
"REFERENCE_TABLES_SCHEMA",
# Fact table schemas
"FACT_INTERVENTIONS_SCHEMA",
"FACT_TABLES_SCHEMA",
# Materialized view schemas
"MV_PATIENT_TREATMENT_SUMMARY_SCHEMA",
"MATERIALIZED_VIEWS_SCHEMA",
# File tracking schemas
"PROCESSED_FILES_SCHEMA",
"FILE_TRACKING_SCHEMA",
# Combined schema
"ALL_TABLES_SCHEMA",
# Reference table functions
"create_reference_tables",
"drop_reference_tables",
"get_reference_table_counts",
"verify_reference_tables_exist",
# Fact table functions
"create_fact_tables",
"drop_fact_tables",
"get_fact_table_counts",
"verify_fact_tables_exist",
# File tracking functions
"create_file_tracking_tables",
"drop_file_tracking_tables",
"get_file_tracking_counts",
"verify_file_tracking_tables_exist",
# Combined functions
"create_all_tables",
"drop_all_tables",
"get_all_table_counts",
"verify_all_tables_exist",
# Reference data migration
"MigrationResult",
"migrate_drug_names",
"get_drug_name_counts",
"verify_drug_names_migration",
"migrate_organizations",
"get_organization_counts",
"verify_organizations_migration",
"migrate_directories",
"get_directory_counts",
"verify_directories_migration",
"migrate_drug_directory_map",
"get_drug_directory_map_counts",
"verify_drug_directory_map_migration",
"migrate_drug_indication_clusters",
"get_drug_indication_cluster_counts",
"verify_drug_indication_clusters_migration",
# Data loader abstractions
"DataLoader",
"FileDataLoader",
"SQLiteDataLoader",
"LoadResult",
"get_loader",
"REQUIRED_COLUMNS",
"OPTIONAL_COLUMNS",
# Patient data migration
"PatientDataLoadResult",
"load_patient_data",
"get_patient_data_stats",
"list_processed_files",
"calculate_file_hash",
# Materialized view functions
"MVRefreshResult",
"refresh_patient_treatment_summary",
"get_patient_summary_stats",
"verify_mv_consistency",
# Snowflake connector
"SnowflakeConnector",
"SnowflakeConnectionError",
"SnowflakeNotConfiguredError",
"SnowflakeNotAvailableError",
"ConnectionInfo",
"get_connector",
"reset_connector",
"is_snowflake_available",
"is_snowflake_configured",
"SNOWFLAKE_AVAILABLE",
# Query result caching
"QueryCache",
"CacheEntry",
"CacheStats",
"get_cache",
"reset_cache",
"is_cache_enabled",
# Data source management with fallback chain
"DataSourceType",
"DataSourceResult",
"SourceStatus",
"DataSourceManager",
"get_data_source_manager",
"get_data",
"reset_data_source_manager",
# Diagnosis lookup
"ClusterSnomedCodes",
"IndicationValidationResult",
"DrugIndicationMatchRate",
"get_drug_clusters",
"get_drug_cluster_ids",
"get_cluster_snomed_codes",
"patient_has_indication",
"validate_indication",
"get_indication_match_rate",
"batch_validate_indications",
"get_available_clusters",
]
+553
View File
@@ -0,0 +1,553 @@
"""
Query result caching module for NHS Patient Pathway Analysis.
Provides file-based caching for Snowflake query results with TTL-based invalidation.
Supports different TTLs for historical data vs data including the current date.
Cache keys are generated from query hashes. Results are stored as compressed JSON.
Usage:
from data_processing.cache import QueryCache, get_cache
cache = get_cache()
# Check for cached result
result = cache.get(query, params)
if result is None:
# Execute query and cache result
result = execute_query(query, params)
cache.set(query, params, result, includes_current_data=False)
"""
from dataclasses import dataclass
from datetime import datetime, date
from pathlib import Path
from typing import Any, Optional
import gzip
import hashlib
import json
import os
import time
from config import get_snowflake_config, CacheConfig
from core.logging_config import get_logger
logger = get_logger(__name__)
@dataclass
class CacheEntry:
    """Metadata for a cached query result.

    Returned by ``QueryCache.set`` and ``QueryCache.list_entries``; it
    describes the entry on disk and does not hold the cached rows themselves.
    """
    # Key derived from the normalised query text plus parameters
    # (first 32 hex characters of a SHA-256 digest).
    cache_key: str
    # First 16 hex characters of the SHA-256 digest of the raw query text.
    query_hash: str
    # When the entry was written (naive local time from datetime.now()).
    created_at: datetime
    # When the entry stops being served by QueryCache.get.
    expires_at: datetime
    # True when the caller flagged the result as covering current data.
    includes_current_data: bool
    # Number of rows in the cached result (len() of the cached list).
    row_count: int
    # Size on disk of the compressed data file, in bytes.
    file_size_bytes: int
    # Location of the gzipped JSON data file.
    file_path: Path
@dataclass
class CacheStats:
    """Statistics about the cache, as reported by ``QueryCache.get_stats``."""
    # Whether caching is switched on in the configuration.
    enabled: bool
    # Directory that holds the cache files.
    cache_dir: Path
    # Number of entries whose metadata file could be read.
    total_entries: int
    # Combined size of every file in the cache directory, in MB.
    total_size_mb: float
    # Configured size limit, in MB.
    max_size_mb: int
    # Creation time of the oldest / newest readable entry (None when empty).
    oldest_entry: Optional[datetime]
    newest_entry: Optional[datetime]
    # In-memory counters; they start at zero for every QueryCache instance.
    hit_count: int
    miss_count: int
class QueryCache:
    """
    File-based cache for Snowflake query results.

    Every entry is stored as two files inside ``cache_dir``:
    ``<key>.json.gz`` (the rows as gzipped JSON) and ``<key>.meta.json``
    (creation time, expiry time, row count and size). Entries expire after a
    TTL taken from the configuration; results flagged as including current
    data use the separate ``ttl_current_data_seconds`` setting.

    Unreadable or malformed entries are treated as cache misses and removed;
    they never raise to the caller. Hit/miss counters are kept in memory only
    and start at zero for every instance.

    Attributes:
        config: CacheConfig with cache settings
        cache_dir: Path to cache directory
    """

    # File-name suffixes of the two files that make up one cache entry.
    _DATA_SUFFIX = ".json.gz"
    _META_SUFFIX = ".meta.json"

    # Errors that mean "this metadata file is unreadable or malformed".
    # ValueError covers json.JSONDecodeError, invalid ISO timestamps and bad
    # encodings; TypeError covers fields of the wrong type.
    _META_ERRORS = (OSError, ValueError, KeyError, TypeError)

    def __init__(self, config: Optional[CacheConfig] = None, base_path: Optional[Path] = None):
        """
        Initialize the query cache.

        Args:
            config: Optional CacheConfig. If not provided, loads from snowflake.toml
            base_path: Base path for relative cache directory. Defaults to cwd.
        """
        if config is None:
            config = get_snowflake_config().cache

        self._config = config
        self._base_path = base_path or Path.cwd()

        # A relative cache directory is resolved against the base path.
        cache_dir = Path(config.directory)
        if not cache_dir.is_absolute():
            cache_dir = self._base_path / cache_dir
        self._cache_dir = cache_dir

        # Stats tracking (in-memory only, reset on restart)
        self._hit_count = 0
        self._miss_count = 0

        # Only touch the file system when caching is actually enabled.
        if self._config.enabled:
            self._cache_dir.mkdir(parents=True, exist_ok=True)

    @property
    def config(self) -> CacheConfig:
        """Return the cache configuration."""
        return self._config

    @property
    def cache_dir(self) -> Path:
        """Return the cache directory path."""
        return self._cache_dir

    @property
    def is_enabled(self) -> bool:
        """Return True if caching is enabled."""
        return self._config.enabled

    def _generate_cache_key(self, query: str, params: Optional[tuple] = None) -> str:
        """
        Generate a cache key from query and parameters.

        The query is lower-cased and its whitespace collapsed, the parameters
        are appended as ``|``-separated strings, and the first 32 hex
        characters of the SHA-256 digest are returned.

        NOTE(review): lower-casing also folds the case of string literals
        embedded in the query text, and the ``|`` join cannot distinguish
        ``("a|b",)`` from ``("a", "b")``. The derivation is left unchanged so
        existing cache files keep their keys; pass values as ``params`` rather
        than as literals if case matters.
        """
        normalized_query = " ".join(query.lower().split())

        key_content = normalized_query
        if params:
            key_content += "|" + "|".join(str(p) for p in params)

        digest = hashlib.sha256(key_content.encode("utf-8")).hexdigest()
        return digest[:32]  # Use first 32 chars for readability

    def _get_cache_file_path(self, cache_key: str) -> Path:
        """Get the file path for a cache entry."""
        return self._cache_dir / f"{cache_key}{self._DATA_SUFFIX}"

    def _get_meta_file_path(self, cache_key: str) -> Path:
        """Get the metadata file path for a cache entry."""
        return self._cache_dir / f"{cache_key}{self._META_SUFFIX}"

    def _key_from_meta_path(self, meta_file: Path) -> str:
        """Recover the cache key from a ``<key>.meta.json`` path."""
        return meta_file.name.removesuffix(self._META_SUFFIX)

    @staticmethod
    def _load_meta(meta_file: Path) -> dict:
        """
        Read one metadata file.

        Raises:
            OSError: the file cannot be read
            ValueError: the content is not valid JSON or not a JSON object
        """
        with open(meta_file, "r", encoding="utf-8") as f:
            meta = json.load(f)
        if not isinstance(meta, dict):
            raise ValueError(f"metadata is not a JSON object: {meta_file.name}")
        return meta

    def _entry_size_bytes(self, cache_key: str) -> int:
        """Return the bytes currently on disk for an entry (data + metadata)."""
        total = 0
        for path in (self._get_cache_file_path(cache_key), self._get_meta_file_path(cache_key)):
            try:
                total += path.stat().st_size
            except OSError:
                # Missing or unreadable file contributes nothing.
                continue
        return total

    def _is_expired(self, meta: dict) -> bool:
        """Check if a cache entry is expired based on its metadata."""
        expires_at = datetime.fromisoformat(meta["expires_at"])
        return datetime.now() > expires_at

    def get(
        self,
        query: str,
        params: Optional[tuple] = None,
        check_expiry: bool = True
    ) -> Optional[list[dict]]:
        """
        Get a cached query result.

        A corrupted entry (bad JSON, bad timestamp, truncated gzip stream,
        missing fields) is counted as a miss and deleted.

        Args:
            query: SQL query string
            params: Optional query parameters
            check_expiry: If True, returns None for expired entries

        Returns:
            Cached result as list of dicts, or None if not cached/expired
        """
        # Local import: zlib.error is raised for corrupt deflate data and is
        # not an OSError subclass.
        import zlib

        if not self.is_enabled:
            self._miss_count += 1
            return None

        cache_key = self._generate_cache_key(query, params)
        cache_file = self._get_cache_file_path(cache_key)
        meta_file = self._get_meta_file_path(cache_key)

        # Both halves of the entry must be present.
        if not cache_file.exists() or not meta_file.exists():
            self._miss_count += 1
            logger.debug(f"Cache miss (not found): {cache_key}")
            return None

        read_errors = self._META_ERRORS + (EOFError, zlib.error)
        try:
            meta = self._load_meta(meta_file)

            if check_expiry and self._is_expired(meta):
                self._miss_count += 1
                logger.debug(f"Cache miss (expired): {cache_key}")
                return None

            with gzip.open(cache_file, "rt", encoding="utf-8") as f:
                data = json.load(f)

            # Read before counting the hit so a malformed entry is recorded
            # as exactly one miss rather than a hit plus a miss.
            row_count = meta["row_count"]
        except read_errors as e:
            logger.warning(f"Cache read error for {cache_key}: {e}")
            self._miss_count += 1
            # Clean up corrupted entry; a failed cleanup must not turn a
            # cache miss into an exception for the caller.
            try:
                self._delete_entry(cache_key)
            except OSError as cleanup_error:
                logger.warning(f"Failed to delete corrupted cache entry {cache_key}: {cleanup_error}")
            return None

        self._hit_count += 1
        logger.info(f"Cache hit: {cache_key} ({row_count} rows)")
        return data

    def set(
        self,
        query: str,
        params: Optional[tuple],
        data: list[dict],
        includes_current_data: bool = False,
        custom_ttl_seconds: Optional[int] = None
    ) -> Optional[CacheEntry]:
        """
        Cache a query result.

        Values that are not JSON-serialisable are stored via ``str()``. If
        writing fails part-way, any partially written files are removed.

        Args:
            query: SQL query string
            params: Optional query parameters
            data: Query result as list of dicts
            includes_current_data: If True, uses the TTL configured for current data
            custom_ttl_seconds: Optional custom TTL (overrides config)

        Returns:
            CacheEntry with metadata, or None if caching disabled/failed
        """
        if not self.is_enabled:
            return None

        cache_key = self._generate_cache_key(query, params)
        cache_file = self._get_cache_file_path(cache_key)
        meta_file = self._get_meta_file_path(cache_key)

        # Determine TTL: explicit override, then current-data TTL, then default.
        if custom_ttl_seconds is not None:
            ttl = custom_ttl_seconds
        elif includes_current_data:
            ttl = self._config.ttl_current_data_seconds
        else:
            ttl = self._config.ttl_seconds

        now = datetime.now()
        expires_at = datetime.fromtimestamp(now.timestamp() + ttl)

        try:
            row_count = len(data)

            # Write compressed data
            with gzip.open(cache_file, "wt", encoding="utf-8", compresslevel=6) as f:
                json.dump(data, f, default=str)

            file_size = cache_file.stat().st_size

            # Write metadata
            meta = {
                "cache_key": cache_key,
                "query_hash": hashlib.sha256(query.encode()).hexdigest()[:16],
                "created_at": now.isoformat(),
                "expires_at": expires_at.isoformat(),
                "includes_current_data": includes_current_data,
                "row_count": row_count,
                "file_size_bytes": file_size,
                "ttl_seconds": ttl,
            }
            with open(meta_file, "w", encoding="utf-8") as f:
                json.dump(meta, f, indent=2)
        except (OSError, TypeError, ValueError) as e:
            logger.error(f"Failed to cache result: {e}")
            # Do not leave half an entry behind: an orphaned data file is
            # invisible to get() but still counts towards the size limit.
            try:
                self._delete_entry(cache_key)
            except OSError:
                pass
            return None

        logger.info(f"Cached {row_count} rows as {cache_key} (expires in {ttl}s)")

        # The entry is on disk; trouble while evicting others is not a
        # failure of this write.
        try:
            self._enforce_size_limit()
        except OSError as e:
            logger.warning(f"Failed to enforce cache size limit: {e}")

        return CacheEntry(
            cache_key=cache_key,
            query_hash=str(meta["query_hash"]),
            created_at=now,
            expires_at=expires_at,
            includes_current_data=includes_current_data,
            row_count=row_count,
            file_size_bytes=file_size,
            file_path=cache_file,
        )

    def invalidate(self, query: str, params: Optional[tuple] = None) -> bool:
        """
        Invalidate a specific cache entry.

        Args:
            query: SQL query string
            params: Optional query parameters

        Returns:
            True if entry was deleted, False if not found
        """
        return self._delete_entry(self._generate_cache_key(query, params))

    def _delete_entry(self, cache_key: str) -> bool:
        """
        Delete a cache entry by key.

        Returns:
            True if at least one of the entry's files was removed.
        """
        deleted = False
        for path in (self._get_cache_file_path(cache_key), self._get_meta_file_path(cache_key)):
            try:
                path.unlink()
            except FileNotFoundError:
                # Already gone (possibly removed concurrently): nothing to do.
                continue
            deleted = True

        if deleted:
            logger.debug(f"Deleted cache entry: {cache_key}")
        return deleted

    def clear(self) -> int:
        """
        Clear all cache entries and reset the hit/miss counters.

        Returns:
            Number of entries deleted (distinct cache keys, so an entry with
            only one of its two files left still counts once)
        """
        if not self._cache_dir.exists():
            return 0

        file_count = 0
        deleted_keys = set()
        for file in self._cache_dir.glob("*.json*"):
            try:
                file.unlink()
            except OSError as e:
                logger.warning(f"Failed to delete {file}: {e}")
                continue
            file_count += 1
            # "<key>.json.gz" and "<key>.meta.json" share the part before
            # the first dot.
            deleted_keys.add(file.name.split(".", 1)[0])

        # Reset stats
        self._hit_count = 0
        self._miss_count = 0

        logger.info(f"Cleared {file_count} cache files")
        return len(deleted_keys)

    def clear_expired(self) -> int:
        """
        Remove expired cache entries.

        Entries whose metadata cannot be read or parsed are removed as well.

        Returns:
            Number of expired entries deleted
        """
        if not self._cache_dir.exists():
            return 0

        count = 0
        for meta_file in self._cache_dir.glob(f"*{self._META_SUFFIX}"):
            cache_key = self._key_from_meta_path(meta_file)
            try:
                expired = self._is_expired(self._load_meta(meta_file))
            except self._META_ERRORS:
                # Corrupted metadata: treat like an expired entry.
                expired = True

            if expired:
                self._delete_entry(cache_key)
                count += 1

        logger.info(f"Cleared {count} expired cache entries")
        return count

    def _get_total_size_mb(self) -> float:
        """Calculate total cache size in MB (all files in the cache directory)."""
        if not self._cache_dir.exists():
            return 0.0

        total_bytes = 0
        for f in self._cache_dir.glob("*"):
            try:
                if f.is_file():
                    total_bytes += f.stat().st_size
            except OSError:
                # File disappeared between listing and stat.
                continue
        return total_bytes / (1024 * 1024)

    def _enforce_size_limit(self) -> int:
        """
        Enforce cache size limit by removing oldest entries.

        When the cache exceeds ``max_size_mb``, entries are removed oldest
        first until the cache is down to about 90% of the limit. Entries with
        unreadable metadata are deleted outright and are not counted in the
        return value.

        Returns:
            Number of entries removed
        """
        max_size_mb = self._config.max_size_mb
        current_size_mb = self._get_total_size_mb()

        if current_size_mb <= max_size_mb:
            return 0

        # Collect (key, creation time) for every readable entry.
        entries = []
        dropped_corrupted = False
        for meta_file in self._cache_dir.glob(f"*{self._META_SUFFIX}"):
            cache_key = self._key_from_meta_path(meta_file)
            try:
                created_at = datetime.fromisoformat(self._load_meta(meta_file)["created_at"])
            except self._META_ERRORS:
                # Clean up corrupted entry
                self._delete_entry(cache_key)
                dropped_corrupted = True
                continue
            entries.append((cache_key, created_at))

        # Deleting corrupted entries already freed space; re-measure so valid
        # entries are not evicted unnecessarily.
        if dropped_corrupted:
            current_size_mb = self._get_total_size_mb()

        # Sort by creation time (oldest first)
        entries.sort(key=lambda entry: entry[1])

        bytes_to_free = (current_size_mb - max_size_mb * 0.9) * 1024 * 1024  # Target 90% of limit
        removed = 0
        freed_bytes = 0
        for cache_key, _created_at in entries:
            if freed_bytes >= bytes_to_free:
                break
            # Measure the real on-disk size instead of trusting the metadata
            # field, which may be missing or stale.
            entry_bytes = self._entry_size_bytes(cache_key)
            self._delete_entry(cache_key)
            freed_bytes += entry_bytes
            removed += 1

        logger.info(f"Removed {removed} cache entries to enforce size limit")
        return removed

    def get_stats(self) -> CacheStats:
        """Get cache statistics (entries with unreadable metadata are skipped)."""
        created_times = []
        if self._cache_dir.exists():
            for meta_file in self._cache_dir.glob(f"*{self._META_SUFFIX}"):
                try:
                    created_times.append(
                        datetime.fromisoformat(self._load_meta(meta_file)["created_at"])
                    )
                except self._META_ERRORS:
                    continue

        return CacheStats(
            enabled=self.is_enabled,
            cache_dir=self._cache_dir,
            total_entries=len(created_times),
            # Returns 0.0 when the directory does not exist.
            total_size_mb=self._get_total_size_mb(),
            max_size_mb=self._config.max_size_mb,
            oldest_entry=min(created_times, default=None),
            newest_entry=max(created_times, default=None),
            hit_count=self._hit_count,
            miss_count=self._miss_count,
        )

    def list_entries(self) -> list[CacheEntry]:
        """List all cache entries with metadata, newest first.

        Entries whose metadata cannot be read or parsed are skipped.
        """
        if not self._cache_dir.exists():
            return []

        entries = []
        for meta_file in self._cache_dir.glob(f"*{self._META_SUFFIX}"):
            try:
                meta = self._load_meta(meta_file)
                cache_key = meta["cache_key"]
                entries.append(CacheEntry(
                    cache_key=cache_key,
                    query_hash=meta.get("query_hash", ""),
                    created_at=datetime.fromisoformat(meta["created_at"]),
                    expires_at=datetime.fromisoformat(meta["expires_at"]),
                    includes_current_data=meta.get("includes_current_data", False),
                    row_count=meta.get("row_count", 0),
                    file_size_bytes=meta.get("file_size_bytes", 0),
                    file_path=self._get_cache_file_path(cache_key),
                ))
            except self._META_ERRORS:
                continue

        # Sort by creation time (newest first)
        entries.sort(key=lambda x: x.created_at, reverse=True)
        return entries
# Shared default cache, created lazily by get_cache() and dropped by reset_cache()
_default_cache: Optional[QueryCache] = None


def get_cache(config: Optional[CacheConfig] = None) -> QueryCache:
    """
    Return a QueryCache, creating the shared default instance on first use.

    Args:
        config: Optional CacheConfig. When given, a brand-new cache built
            from it is returned and the shared default is left untouched.
            When None, the shared default cache is returned.

    Returns:
        QueryCache instance
    """
    global _default_cache

    if config is None:
        # Default path: build the singleton once, then keep handing it out.
        if _default_cache is None:
            _default_cache = QueryCache()
        return _default_cache

    # Explicit configuration: always a fresh, unshared instance.
    return QueryCache(config)
def reset_cache() -> None:
    """Reset the default cache singleton.

    Only the in-memory reference is dropped, so the next ``get_cache()`` call
    without a config builds a new ``QueryCache``. Cache files on disk are not
    touched.
    """
    global _default_cache
    _default_cache = None
def is_cache_enabled() -> bool:
    """Report whether the loaded Snowflake configuration has caching switched on."""
    return get_snowflake_config().cache.enabled
# Export public API
__all__ = [
"QueryCache",
"CacheEntry",
"CacheStats",
"get_cache",
"reset_cache",
"is_cache_enabled",
]
+968
View File
@@ -0,0 +1,968 @@
"""
Unified data access layer with fallback chain for NHS Patient Pathway Analysis.
Provides a high-level interface that automatically selects the best available data source:
1. Cache - Returns cached results if valid and not expired
2. Snowflake - Queries Snowflake warehouse if configured and connected
3. Local - Falls back to SQLite database or CSV/Parquet files
The fallback chain handles connection errors, missing configurations, and
unavailable services gracefully, always attempting to provide data from
some source.
Usage:
from data_processing.data_source import DataSourceManager, get_data
# Simple usage with automatic source selection
result = get_data(
start_date=date(2024, 1, 1),
end_date=date(2024, 12, 31),
trusts=["TRUST A", "TRUST B"],
)
# Or with explicit source preference
manager = DataSourceManager()
result = manager.get_data(
start_date=date(2024, 1, 1),
end_date=date(2024, 12, 31),
preferred_source="snowflake",
)
"""
from dataclasses import dataclass, field
from datetime import date, datetime
from enum import Enum
from pathlib import Path
from typing import Optional, Callable
import pandas as pd
from core.logging_config import get_logger
logger = get_logger(__name__)
class DataSourceType(Enum):
    """Enumeration of available data sources.

    Members are declared in the order the fallback chain prefers them; the
    values are the lowercase names used when reporting a source.
    """
    CACHE = "cache"          # previously cached query results
    SNOWFLAKE = "snowflake"  # Snowflake warehouse
    SQLITE = "sqlite"        # local SQLite database
    FILE = "file"            # local CSV/Parquet file
@dataclass
class DataSourceResult:
    """Outcome of a data source query.

    Attributes:
        df: DataFrame holding the loaded patient intervention data
        source_type: The data source that produced the result
        source_detail: Extra information about the source (e.g., file path, query hash)
        row_count: Number of rows returned; filled in from ``df`` when left at 0
        cached: True when the result was served from cache
        from_fallback: True when a fallback source had to be used
        load_time_seconds: How long loading took, in seconds
        warnings: Warnings collected while loading
    """
    df: pd.DataFrame
    source_type: DataSourceType
    source_detail: str = ""
    row_count: int = 0
    cached: bool = False
    from_fallback: bool = False
    load_time_seconds: float = 0.0
    warnings: list[str] = field(default_factory=list)

    def __post_init__(self):
        # An explicitly supplied row count wins, and without a frame there
        # is nothing to count.
        if self.row_count != 0 or self.df is None:
            return
        self.row_count = len(self.df)
@dataclass
class SourceStatus:
    """Status of a data source.

    Produced by the ``DataSourceManager`` status checks; every field except
    ``source_type`` defaults to the "not available" state.

    Attributes:
        source_type: The type of data source
        available: Whether the source is available
        configured: Whether the source is properly configured
        message: Status message explaining the state
        last_checked: When the status was last checked
    """
    source_type: DataSourceType
    available: bool = False
    configured: bool = False
    message: str = ""
    # None until a check has actually been run.
    last_checked: Optional[datetime] = None
class DataSourceManager:
"""
Manages data access with automatic fallback between sources.
The manager attempts to retrieve data from sources in order of preference:
1. Cache (if enabled and has valid cached data)
2. Snowflake (if configured and connected)
3. SQLite (if database exists with data)
4. Local files (CSV/Parquet)
Attributes:
cache_enabled: Whether to use caching
local_file_path: Path to local CSV/Parquet file (optional fallback)
sqlite_db_path: Path to SQLite database (optional)
Example:
manager = DataSourceManager()
# Check what sources are available
status = manager.check_all_sources()
for s in status:
print(f"{s.source_type.value}: {s.message}")
# Get data with automatic fallback
result = manager.get_data(
start_date=date(2024, 1, 1),
end_date=date(2024, 6, 30),
)
print(f"Got {result.row_count} rows from {result.source_type.value}")
"""
def __init__(
self,
cache_enabled: bool = True,
local_file_path: Optional[Path | str] = None,
sqlite_db_path: Optional[Path | str] = None,
):
"""
Initialize the data source manager.
Args:
cache_enabled: Whether to check cache before querying (default True)
local_file_path: Path to local CSV/Parquet file for file fallback
sqlite_db_path: Path to SQLite database (uses default if None)
"""
self._cache_enabled = cache_enabled
self._local_file_path = Path(local_file_path) if local_file_path else None
self._sqlite_db_path = Path(sqlite_db_path) if sqlite_db_path else None
self._source_status: dict[DataSourceType, SourceStatus] = {}
@property
def cache_enabled(self) -> bool:
    """Return whether caching is enabled.

    This is the manager-level switch checked by ``_try_cache``; it is
    separate from the cache's own configuration flag.
    """
    return self._cache_enabled

@cache_enabled.setter
def cache_enabled(self, value: bool) -> None:
    """Set whether caching is enabled."""
    self._cache_enabled = value
def _check_cache_status(self) -> SourceStatus:
    """Report whether the query cache can be used.

    Any exception raised while importing or inspecting the cache is turned
    into an "unavailable" status carrying the error text.
    """
    def report(usable: bool, message: str) -> SourceStatus:
        # For the cache, "available" and "configured" always agree.
        return SourceStatus(
            source_type=DataSourceType.CACHE,
            available=usable,
            configured=usable,
            message=message,
            last_checked=datetime.now(),
        )

    try:
        from data_processing.cache import is_cache_enabled, get_cache

        if not is_cache_enabled():
            return report(False, "Cache disabled in configuration")

        stats = get_cache().get_stats()
        return report(
            True,
            f"Cache enabled ({stats.total_entries} entries, {stats.total_size_mb:.1f}MB)",
        )
    except Exception as e:
        return report(False, f"Cache error: {e}")
def _check_snowflake_status(self) -> SourceStatus:
    """Report whether Snowflake is installed and configured.

    No connection is attempted here; only the connector's availability and
    configuration helpers are consulted. Any exception becomes an
    "unavailable" status carrying the error text.
    """
    def report(available: bool, configured: bool, message: str) -> SourceStatus:
        return SourceStatus(
            source_type=DataSourceType.SNOWFLAKE,
            available=available,
            configured=configured,
            message=message,
            last_checked=datetime.now(),
        )

    try:
        from data_processing.snowflake_connector import (
            is_snowflake_available,
            is_snowflake_configured,
        )

        if not is_snowflake_available():
            return report(False, False, "snowflake-connector-python not installed")
        if not is_snowflake_configured():
            return report(True, False, "Snowflake account not configured in config/snowflake.toml")
        return report(True, True, "Snowflake configured and ready")
    except Exception as e:
        return report(False, False, f"Snowflake error: {e}")
def _check_sqlite_status(self) -> SourceStatus:
    """Check if SQLite database is available with data.

    The database counts as available only when the file exists, contains a
    ``fact_interventions`` table, and that table has at least one row. The
    path given to the constructor is used when set, otherwise the default
    database configuration's path.

    Returns:
        SourceStatus for the SQLite source. Any exception raised during the
        check is reported as an unavailable, unconfigured source carrying
        the error text.
    """
    def report(available: bool, configured: bool, message: str) -> SourceStatus:
        return SourceStatus(
            source_type=DataSourceType.SQLITE,
            available=available,
            configured=configured,
            message=message,
            last_checked=datetime.now(),
        )

    try:
        # Single import of exactly the names used below (the previously
        # imported default_db_manager was never referenced).
        from data_processing.database import (
            DatabaseConfig,
            DatabaseManager,
            default_db_config,
        )

        db_path = self._sqlite_db_path or Path(default_db_config.db_path)
        if not db_path.exists():
            return report(False, True, f"Database not found: {db_path}")

        manager = DatabaseManager(DatabaseConfig(db_path=db_path))

        if not manager.table_exists("fact_interventions"):
            return report(False, True, "fact_interventions table not found")

        count = manager.get_table_count("fact_interventions")
        if count == 0:
            return report(False, True, "fact_interventions table is empty")

        return report(True, True, f"SQLite database ready ({count:,} rows)")
    except Exception as e:
        return report(False, False, f"SQLite error: {e}")
def _check_file_status(self) -> SourceStatus:
    """Report whether the configured local file exists and how large it is."""
    path = self._local_file_path

    if path is None:
        available, configured = False, False
        message = "No local file path configured"
    elif not path.exists():
        available, configured = False, True
        message = f"File not found: {self._local_file_path}"
    else:
        available, configured = True, True
        size_mb = path.stat().st_size / (1024 * 1024)
        message = f"Local file ready: {self._local_file_path.name} ({size_mb:.1f}MB)"

    return SourceStatus(
        source_type=DataSourceType.FILE,
        available=available,
        configured=configured,
        message=message,
        last_checked=datetime.now(),
    )
def check_source_status(self, source_type: DataSourceType) -> SourceStatus:
"""
Check the status of a specific data source.
Args:
source_type: The type of source to check
Returns:
SourceStatus with current availability information
"""
if source_type == DataSourceType.CACHE:
return self._check_cache_status()
elif source_type == DataSourceType.SNOWFLAKE:
return self._check_snowflake_status()
elif source_type == DataSourceType.SQLITE:
return self._check_sqlite_status()
elif source_type == DataSourceType.FILE:
return self._check_file_status()
else:
return SourceStatus(
source_type=source_type,
available=False,
configured=False,
message=f"Unknown source type: {source_type}",
last_checked=datetime.now(),
)
def check_all_sources(self) -> list[SourceStatus]:
    """
    Check every data source and remember the outcome.

    Each result is also stored in the manager's per-source status map.

    Returns:
        One SourceStatus per source type, in enum declaration order
    """
    results: list[SourceStatus] = []
    for kind in DataSourceType:
        current = self.check_source_status(kind)
        self._source_status[kind] = current
        results += [current]
    return results
def _build_cache_key_params(
self,
start_date: Optional[date],
end_date: Optional[date],
trusts: Optional[list[str]],
drugs: Optional[list[str]],
directories: Optional[list[str]],
) -> tuple[str, tuple]:
"""Build a cache-compatible query string and params for the filter criteria."""
# Create a canonical representation for caching
query_parts = ["SELECT * FROM activity_data"]
params = []
conditions = []
if start_date:
conditions.append("start_date >= ?")
params.append(str(start_date))
if end_date:
conditions.append("end_date <= ?")
params.append(str(end_date))
if trusts:
placeholders = ",".join(["?"] * len(trusts))
conditions.append(f"trust IN ({placeholders})")
params.extend(sorted(trusts))
if drugs:
placeholders = ",".join(["?"] * len(drugs))
conditions.append(f"drug IN ({placeholders})")
params.extend(sorted(drugs))
if directories:
placeholders = ",".join(["?"] * len(directories))
conditions.append(f"directory IN ({placeholders})")
params.extend(sorted(directories))
if conditions:
query_parts.append("WHERE " + " AND ".join(conditions))
query = " ".join(query_parts)
return query, tuple(params)
def _try_cache(
    self,
    start_date: Optional[date],
    end_date: Optional[date],
    trusts: Optional[list[str]],
    drugs: Optional[list[str]],
    directories: Optional[list[str]],
) -> Optional[DataSourceResult]:
    """Look the filter criteria up in the query cache.

    Returns None when caching is switched off (on the manager or in the
    cache itself), when nothing is cached for these criteria, or when the
    lookup raises; failures are logged as warnings and never propagate.
    """
    if not self._cache_enabled:
        return None

    try:
        from data_processing.cache import get_cache

        cache = get_cache()
        if not cache.is_enabled:
            return None

        query, params = self._build_cache_key_params(
            start_date, end_date, trusts, drugs, directories
        )
        rows = cache.get(query, params)
        if rows is None:
            logger.debug("Cache miss")
            return None

        # Rebuild the frame from the cached rows.
        frame = pd.DataFrame(rows)

        # The date column, when present, is converted back to datetimes.
        date_column = 'Intervention Date'
        if date_column in frame.columns:
            frame[date_column] = pd.to_datetime(frame[date_column])

        n_rows = len(frame)
        logger.info(f"Cache hit: {n_rows} rows")
        return DataSourceResult(
            df=frame,
            source_type=DataSourceType.CACHE,
            source_detail=f"cache_key={query[:50]}...",
            row_count=n_rows,
            cached=True,
            from_fallback=False,
        )
    except Exception as e:
        logger.warning(f"Cache lookup failed: {e}")
        return None
def _try_snowflake(
    self,
    start_date: Optional[date],
    end_date: Optional[date],
    trusts: Optional[list[str]],
    drugs: Optional[list[str]],
    directories: Optional[list[str]],
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> Optional[DataSourceResult]:
    """Try to get data from Snowflake.

    Only the date range is pushed down to Snowflake. The trust, drug and
    directory filters are applied locally after the project's
    transformation helpers have run, and each is applied only if its column
    is present in the result.

    Args:
        start_date: Lower date bound passed to the connector (optional)
        end_date: Upper date bound passed to the connector (optional)
        trusts: Keep only rows whose 'OrganisationName' is in this list
        drugs: Keep only rows whose 'Drug Name' is in this list
        directories: Keep only rows whose 'Directory' is in this list
        progress_callback: Accepted for interface compatibility; currently
            not forwarded to the connector and never called

    Returns:
        DataSourceResult on success. None when the connector is not
        installed, Snowflake is not configured, the query returns no rows,
        or any step raises; failures are logged as warnings so the caller
        can fall back to another source.
    """
    import time

    try:
        from data_processing.snowflake_connector import (
            is_snowflake_available,
            is_snowflake_configured,
            get_connector,
        )

        if not is_snowflake_available():
            logger.debug("Snowflake connector not installed")
            return None

        if not is_snowflake_configured():
            logger.debug("Snowflake not configured")
            return None

        # Get connector and fetch data
        connector = get_connector()

        logger.info("Fetching data from Snowflake...")
        # perf_counter is monotonic, so the elapsed time cannot be skewed by
        # system clock adjustments during a long fetch.
        start_time = time.perf_counter()

        # Fetch activity data from Snowflake
        # Note: provider_codes filter not directly supported yet - would need trust name to code mapping
        rows = connector.fetch_activity_data(
            start_date=start_date,
            end_date=end_date,
            provider_codes=None,  # TODO: map trust names to provider codes if needed
        )

        if not rows:
            logger.warning("Snowflake returned no data")
            return None

        # Convert to DataFrame
        df = pd.DataFrame(rows)

        # Covers the fetch and DataFrame construction only, not the local
        # transformations below.
        load_time = time.perf_counter() - start_time
        logger.info(f"Snowflake loaded {len(df)} rows in {load_time:.2f}s")

        # Apply local transformations to match expected format
        # (patient_id, drug_names, department_identification)
        from tools.data import patient_id, drug_names, department_identification
        from core import default_paths

        df = patient_id(df)
        df = drug_names(df, paths=default_paths)
        df = department_identification(df, paths=default_paths)

        # Apply additional filters if provided
        if trusts and 'OrganisationName' in df.columns:
            df = df[df['OrganisationName'].isin(trusts)]

        if drugs and 'Drug Name' in df.columns:
            df = df[df['Drug Name'].isin(drugs)]

        if directories and 'Directory' in df.columns:
            df = df[df['Directory'].isin(directories)]

        return DataSourceResult(
            df=df,
            source_type=DataSourceType.SNOWFLAKE,
            source_detail="DATA_HUB.CDM.Acute__Conmon__PatientLevelDrugs",
            row_count=len(df),
            cached=False,
            from_fallback=False,
            load_time_seconds=load_time,
        )

    except Exception as e:
        logger.warning(f"Snowflake query failed: {e}")
        return None
def _try_sqlite(
    self,
    start_date: Optional[date],
    end_date: Optional[date],
    trusts: Optional[list[str]],
    drugs: Optional[list[str]],
    directories: Optional[list[str]],
) -> Optional[DataSourceResult]:
    """
    Load data through SQLiteDataLoader from the configured database.

    Args:
        start_date: Optional start date for filtering
        end_date: Optional end date for filtering
        trusts: Optional trust names passed to the loader
        drugs: Optional drug names passed to the loader
        directories: Optional directories passed to the loader

    Returns:
        DataSourceResult tagged as DataSourceType.SQLITE, or None when the
        loader reports the source as invalid or any step raises.
    """
    import time

    try:
        from data_processing.loader import SQLiteDataLoader

        # Fall back to the application-wide default database location.
        database_file = self._sqlite_db_path
        if database_file is None:
            from data_processing.database import default_db_config
            database_file = Path(default_db_config.db_path)

        # NOTE(review): the range is only forwarded when BOTH bounds are set;
        # with a single bound the loader receives date_range=None.
        if start_date and end_date:
            window = (start_date, end_date)
        else:
            window = None

        sqlite_loader = SQLiteDataLoader(
            db_path=database_file,
            date_range=window,
            trusts=trusts,
            drugs=drugs,
            directories=directories,
        )

        source_ok, reason = sqlite_loader.validate_source()
        if not source_ok:
            logger.debug(f"SQLite not available: {reason}")
            return None

        load_started = time.time()
        loaded = sqlite_loader.load()
        elapsed = time.time() - load_started
        logger.info(f"SQLite loaded {loaded.row_count} rows in {elapsed:.2f}s")

        return DataSourceResult(
            df=loaded.df,
            source_type=DataSourceType.SQLITE,
            source_detail=str(database_file),
            row_count=loaded.row_count,
            cached=False,
            from_fallback=False,
            load_time_seconds=elapsed,
        )
    except Exception as exc:
        logger.warning(f"SQLite query failed: {exc}")
        return None
def _try_file(
    self,
    start_date: Optional[date],
    end_date: Optional[date],
    trusts: Optional[list[str]],
    drugs: Optional[list[str]],
    directories: Optional[list[str]],
) -> Optional[DataSourceResult]:
    """
    Load the configured local file and filter it in memory.

    Args:
        start_date: Optional start date (inclusive) on 'Intervention Date'
        end_date: Optional end date (exclusive) on 'Intervention Date'
        trusts: Optional trust names, matched against 'OrganisationName'
        drugs: Optional drug names, matched against 'Drug Name'
        directories: Optional directories, matched against 'Directory'

    Returns:
        DataSourceResult tagged as DataSourceType.FILE with from_fallback=True,
        or None when no file is configured, the loader reports it invalid, or
        any step raises.
    """
    import time

    if self._local_file_path is None:
        logger.debug("No local file configured")
        return None

    try:
        from data_processing.loader import FileDataLoader

        file_loader = FileDataLoader(file_path=self._local_file_path)
        source_ok, reason = file_loader.validate_source()
        if not source_ok:
            logger.debug(f"Local file not available: {reason}")
            return None

        load_started = time.time()
        frame = file_loader.load().df

        # The file loader returns everything; each filter is applied here and
        # skipped when its column is missing.
        date_column = 'Intervention Date'
        if start_date and date_column in frame.columns:
            frame = frame[frame[date_column] >= pd.Timestamp(start_date)]
        if end_date and date_column in frame.columns:
            frame = frame[frame[date_column] < pd.Timestamp(end_date)]

        for column, wanted in (
            ('OrganisationName', trusts),
            ('Drug Name', drugs),
            ('Directory', directories),
        ):
            if wanted and column in frame.columns:
                frame = frame[frame[column].isin(wanted)]

        elapsed = time.time() - load_started
        logger.info(f"File loaded and filtered: {len(frame)} rows in {elapsed:.2f}s")

        return DataSourceResult(
            df=frame,
            source_type=DataSourceType.FILE,
            source_detail=str(self._local_file_path),
            row_count=len(frame),
            cached=False,
            from_fallback=True,
            load_time_seconds=elapsed,
        )
    except Exception as exc:
        logger.warning(f"File load failed: {exc}")
        return None
def get_data(
    self,
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    trusts: Optional[list[str]] = None,
    drugs: Optional[list[str]] = None,
    directories: Optional[list[str]] = None,
    preferred_source: Optional[str] = None,
    skip_cache: bool = False,
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> DataSourceResult:
    """
    Get patient intervention data from the best available source.

    If ``preferred_source`` is given it is tried first (bypassing the cache).
    When it is unavailable or not recognised, the standard fallback chain
    runs: Cache → Snowflake → SQLite → File. Any warning raised while
    handling the preferred source is attached to whichever result is
    eventually returned.

    Args:
        start_date: Optional start date for filtering (inclusive)
        end_date: Optional end date for filtering (exclusive)
        trusts: Optional list of trust names to filter
        drugs: Optional list of drug names to filter
        directories: Optional list of directories to filter
        preferred_source: Optional preferred source ("snowflake", "sqlite", "file")
        skip_cache: If True, do not read from the cache (fresh Snowflake
            results are still written to it when caching is enabled)
        progress_callback: Optional callback(current, total) for progress updates

    Returns:
        DataSourceResult with the loaded data and metadata

    Raises:
        ValueError: If no data source is available or all sources fail
    """
    import time

    started = time.time()
    warnings: list[str] = []
    filters = (start_date, end_date, trusts, drugs, directories)

    def _finish(result, *, fallback: bool = False):
        # Stamp shared metadata on a successful result before returning it.
        if fallback:
            result.from_fallback = True
        result.load_time_seconds = time.time() - started
        if warnings:
            result.warnings.extend(warnings)
        return result

    # If a preferred source is specified, try that first.
    if preferred_source:
        preferred = preferred_source.lower()
        preferred_loaders = {
            "snowflake": lambda: self._try_snowflake(*filters, progress_callback),
            "sqlite": lambda: self._try_sqlite(*filters),
            "file": lambda: self._try_file(*filters),
        }
        loader = preferred_loaders.get(preferred)
        if loader is None:
            # Previously an unrecognised value was ignored without any trace.
            warnings.append(
                f"Unknown preferred source '{preferred_source}' - "
                "using default fallback chain"
            )
        else:
            result = loader()
            if result:
                return _finish(result)
            warnings.append(f"Preferred source '{preferred}' unavailable")

    # Standard fallback chain: cache → snowflake → sqlite → file

    # 1. Cache (unless skipped)
    if not skip_cache:
        result = self._try_cache(*filters)
        if result:
            return _finish(result)

    # 2. Snowflake
    result = self._try_snowflake(*filters, progress_callback)
    if result:
        # Cache the result for future queries.
        if self._cache_enabled:
            self._cache_result(
                result.df,
                *filters,
                includes_current_data=end_date is None or end_date >= date.today(),
            )
        return _finish(result)

    # 3. SQLite - marked as fallback since Snowflake wasn't used
    result = self._try_sqlite(*filters)
    if result:
        return _finish(result, fallback=True)

    # 4. Local file
    result = self._try_file(*filters)
    if result:
        return _finish(result, fallback=True)

    # All sources failed: report the status of each one in the error.
    source_status = self.check_all_sources()
    status_msg = "; ".join(
        f"{s.source_type.value}: {s.message}" for s in source_status
    )
    raise ValueError(f"No data source available. Status: {status_msg}")
def _cache_result(
    self,
    df: pd.DataFrame,
    start_date: Optional[date],
    end_date: Optional[date],
    trusts: Optional[list[str]],
    drugs: Optional[list[str]],
    directories: Optional[list[str]],
    includes_current_data: bool = False,
) -> bool:
    """
    Store a query result in the cache under the key for these filters.

    Args:
        df: Result DataFrame to store (not modified; a copy is serialised)
        start_date: Optional start date used to build the cache key
        end_date: Optional end date used to build the cache key
        trusts: Optional trust names used to build the cache key
        drugs: Optional drug names used to build the cache key
        directories: Optional directories used to build the cache key
        includes_current_data: Forwarded unchanged to ``cache.set``

    Returns:
        True if the cache returned an entry; False if the cache is disabled,
        returned nothing, or raised.
    """
    try:
        from data_processing.cache import get_cache

        store = get_cache()
        if not store.is_enabled:
            return False

        key_query, key_params = self._build_cache_key_params(
            start_date, end_date, trusts, drugs, directories
        )

        # Work on a copy and stringify datetime columns so the records can be
        # JSON-serialised by the cache.
        serialisable = df.copy()
        for column in serialisable.columns:
            if not pd.api.types.is_datetime64_any_dtype(serialisable[column]):
                continue
            serialisable[column] = serialisable[column].astype(str)
        records = serialisable.to_dict(orient='records')

        stored = store.set(
            key_query, key_params, records,
            includes_current_data=includes_current_data
        )
        if not stored:
            return False

        logger.info(f"Cached {len(records)} rows (key={stored.cache_key[:16]}...)")
        return True
    except Exception as exc:
        # Caching is best-effort; never let it break the data request.
        logger.warning(f"Failed to cache result: {exc}")
        return False
def clear_cache(self) -> int:
    """
    Remove every entry from the shared cache.

    Returns:
        Whatever ``cache.clear()`` reports (the number of entries cleared),
        or 0 if the cache could not be loaded or clearing raised.
    """
    try:
        from data_processing.cache import get_cache

        cleared = get_cache().clear()
    except Exception as exc:
        logger.warning(f"Failed to clear cache: {exc}")
        return 0
    return cleared
def refresh_from_snowflake(
    self,
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    trusts: Optional[list[str]] = None,
    drugs: Optional[list[str]] = None,
    directories: Optional[list[str]] = None,
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> DataSourceResult:
    """
    Query Snowflake directly, ignoring the cache and all other sources.

    Unlike get_data(), this never falls back: it raises when Snowflake is
    missing, unconfigured, or the query yields no result. A successful result
    is written to the cache when caching is enabled.

    Args:
        start_date: Optional start date for filtering
        end_date: Optional end date for filtering
        trusts: Optional list of trust names
        drugs: Optional list of drug names
        directories: Optional list of directories
        progress_callback: Optional progress callback

    Returns:
        DataSourceResult from Snowflake

    Raises:
        ValueError: If Snowflake is not available or query fails
    """
    from data_processing.snowflake_connector import (
        is_snowflake_available,
        is_snowflake_configured,
    )

    # Fail loudly up front so the caller gets a specific reason.
    if not is_snowflake_available():
        raise ValueError("Snowflake connector not installed")
    if not is_snowflake_configured():
        raise ValueError("Snowflake not configured - edit config/snowflake.toml")

    fresh = self._try_snowflake(
        start_date, end_date, trusts, drugs, directories, progress_callback
    )
    if fresh is None:
        raise ValueError("Snowflake query failed - check logs for details")

    if self._cache_enabled:
        # An open-ended or future end date means the result covers "today".
        covers_today = end_date is None or end_date >= date.today()
        self._cache_result(
            fresh.df,
            start_date, end_date, trusts, drugs, directories,
            includes_current_data=covers_today
        )
    return fresh
# Module-level singleton and convenience functions
# Created lazily by get_data_source_manager() when no custom paths are given,
# and set back to None by reset_data_source_manager().
_default_manager: Optional[DataSourceManager] = None
def get_data_source_manager(
    cache_enabled: bool = True,
    local_file_path: Optional[Path | str] = None,
    sqlite_db_path: Optional[Path | str] = None,
) -> DataSourceManager:
    """
    Return a DataSourceManager, reusing the module singleton when possible.

    Passing either path always builds a new, independent manager. Without
    paths the shared singleton is returned and created on first use.

    Args:
        cache_enabled: Whether to enable caching. For the singleton this only
            takes effect on the call that creates it; later calls get the
            existing instance regardless of this value.
        local_file_path: Optional path to local CSV/Parquet file
        sqlite_db_path: Optional path to SQLite database

    Returns:
        DataSourceManager instance
    """
    global _default_manager

    if not (local_file_path or sqlite_db_path):
        # Shared instance: build once, then hand out the same object.
        if _default_manager is None:
            _default_manager = DataSourceManager(cache_enabled=cache_enabled)
        return _default_manager

    # Custom paths: dedicated manager that is not stored as the singleton.
    return DataSourceManager(
        cache_enabled=cache_enabled,
        local_file_path=local_file_path,
        sqlite_db_path=sqlite_db_path,
    )
def get_data(
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    trusts: Optional[list[str]] = None,
    drugs: Optional[list[str]] = None,
    directories: Optional[list[str]] = None,
    preferred_source: Optional[str] = None,
    skip_cache: bool = False,
) -> DataSourceResult:
    """
    Fetch data through the default manager in one call.

    Thin wrapper: obtains the manager from get_data_source_manager() and
    forwards every argument to its get_data() method by keyword.

    Args:
        start_date: Optional start date for filtering
        end_date: Optional end date for filtering
        trusts: Optional list of trust names
        drugs: Optional list of drug names
        directories: Optional list of directories
        preferred_source: Optional preferred source
        skip_cache: If True, bypass cache

    Returns:
        DataSourceResult with loaded data
    """
    request = {
        "start_date": start_date,
        "end_date": end_date,
        "trusts": trusts,
        "drugs": drugs,
        "directories": directories,
        "preferred_source": preferred_source,
        "skip_cache": skip_cache,
    }
    return get_data_source_manager().get_data(**request)
def reset_data_source_manager() -> None:
    """
    Reset the default data source manager singleton.

    Sets the module-level ``_default_manager`` back to None, so the next
    get_data_source_manager() call without custom paths builds a new
    DataSourceManager. Manager objects already returned to callers are not
    touched.
    """
    global _default_manager
    _default_manager = None
# Export public API
# Names listed here are what `from <this module> import *` exposes; the
# underscore-prefixed singleton above is deliberately not included.
__all__ = [
    "DataSourceType",
    "DataSourceResult",
    "SourceStatus",
    "DataSourceManager",
    "get_data_source_manager",
    "get_data",
    "reset_data_source_manager",
]
+239
View File
@@ -0,0 +1,239 @@
"""
SQLite database connection management for NHS High-Cost Drug Patient Pathway Analysis Tool.
Provides connection management, schema initialization, and common database operations.
Uses context manager pattern for safe resource handling.
"""
import sqlite3
from contextlib import contextmanager
from pathlib import Path
from typing import Optional, Generator, Literal
from core.logging_config import get_logger
logger = get_logger(__name__)
class DatabaseConfig:
    """
    Where the SQLite database lives and how connections to it are opened.

    Attributes:
        db_path: Path to the SQLite database file
        timeout: Connection timeout in seconds (default: 30)
        isolation_level: Transaction isolation level (default: None for autocommit)
    """

    DEFAULT_DB_NAME = "pathways.db"

    def __init__(
        self,
        db_path: Optional[Path] = None,
        data_dir: Optional[Path] = None,
        timeout: float = 30.0,
        isolation_level: Optional[Literal['DEFERRED', 'EXCLUSIVE', 'IMMEDIATE']] = None
    ):
        """
        Resolve the database location and store the connection settings.

        Args:
            db_path: Full path to database file. Takes precedence over data_dir.
                If None, uses data_dir/DEFAULT_DB_NAME.
            data_dir: Directory to place database in. Defaults to ./data/
            timeout: Connection timeout in seconds.
            isolation_level: Transaction isolation level. None = autocommit.
        """
        if db_path is not None:
            resolved = Path(db_path)
        else:
            # No explicit file: use the default file name inside data_dir,
            # or inside ./data when no directory was given either.
            directory = Path("./data") if data_dir is None else Path(data_dir)
            resolved = directory / self.DEFAULT_DB_NAME

        self.db_path = resolved
        self.timeout = timeout
        self.isolation_level = isolation_level

    def validate(self) -> list[str]:
        """
        Check that the configuration can be used.

        Only the existence of the database file's parent directory is
        checked; the file itself may not exist yet.

        Returns:
            List of error messages. Empty list means configuration is valid.
        """
        directory = self.db_path.parent
        if directory.exists():
            return []
        return [f"Database directory does not exist: {directory}"]
class DatabaseManager:
    """
    Manages SQLite database connections and operations.

    Provides context managers for safe connection handling and methods
    for common database operations. Every connection is opened with foreign
    key enforcement switched on and ``sqlite3.Row`` as the row factory.

    Usage:
        db_manager = DatabaseManager()

        # Using context manager (recommended)
        with db_manager.get_connection() as conn:
            cursor = conn.execute("SELECT * FROM ref_drug_names")
            results = cursor.fetchall()

        # Or get a managed connection for longer operations
        conn = db_manager.connect()
        try:
            # ... do work ...
        finally:
            conn.close()
    """

    def __init__(self, config: Optional[DatabaseConfig] = None):
        """
        Initialize the database manager.

        Args:
            config: Database configuration. If None, uses default configuration.
        """
        self.config = config or DatabaseConfig()
        self._connection: Optional[sqlite3.Connection] = None

    @property
    def db_path(self) -> Path:
        """Path to the SQLite database file."""
        return self.config.db_path

    @property
    def exists(self) -> bool:
        """Check if the database file exists."""
        return self.db_path.exists()

    def _open(self, isolation_level: Optional[str]) -> sqlite3.Connection:
        """Open a configured connection using the given isolation level."""
        conn = sqlite3.connect(
            str(self.db_path),
            timeout=self.config.timeout,
            isolation_level=isolation_level,
        )
        try:
            # Enable foreign key support (off by default in SQLite).
            conn.execute("PRAGMA foreign_keys = ON")
            # Return rows as sqlite3.Row for dict-like access.
            conn.row_factory = sqlite3.Row
        except Exception:
            # Don't leak the handle if setup fails after the file was opened.
            conn.close()
            raise
        return conn

    def connect(self) -> sqlite3.Connection:
        """
        Create a new database connection.

        Returns:
            sqlite3.Connection: New database connection using the configured
            isolation level.

        Note:
            The caller is responsible for closing the connection.
            Consider using get_connection() context manager instead.
        """
        return self._open(self.config.isolation_level)

    @contextmanager
    def get_connection(self) -> Generator[sqlite3.Connection, None, None]:
        """
        Context manager for database connections.

        Rolls back and re-raises if the body raises; always closes the
        connection. It does NOT commit on success.

        Yields:
            sqlite3.Connection: Database connection.

        Example:
            with db_manager.get_connection() as conn:
                conn.execute("INSERT INTO table VALUES (?)", (value,))
                conn.commit()
        """
        conn = self.connect()
        try:
            yield conn
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    @contextmanager
    def get_transaction(self) -> Generator[sqlite3.Connection, None, None]:
        """
        Context manager for transactional operations.

        Automatically commits on success, rolls back on exception. Always uses
        the "DEFERRED" isolation level, regardless of the configured one.

        Yields:
            sqlite3.Connection: Database connection in transaction mode.

        Example:
            with db_manager.get_transaction() as conn:
                conn.execute("INSERT INTO table VALUES (?)", (value1,))
                conn.execute("INSERT INTO other_table VALUES (?)", (value2,))
            # Auto-commits if no exception
        """
        conn = self._open("DEFERRED")  # Explicit transaction mode
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    def execute_script(self, sql_script: str) -> None:
        """
        Execute a SQL script (multiple statements).

        Args:
            sql_script: SQL script containing one or more statements.
        """
        with self.get_connection() as conn:
            conn.executescript(sql_script)
            logger.info("Executed SQL script successfully")

    def table_exists(self, table_name: str) -> bool:
        """
        Check if a table exists in the database.

        Args:
            table_name: Name of the table to check.

        Returns:
            True if the table exists, False otherwise.
        """
        with self.get_connection() as conn:
            cursor = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
                (table_name,)
            )
            return cursor.fetchone() is not None

    def get_table_count(self, table_name: str) -> int:
        """
        Get the row count for a table.

        Args:
            table_name: Name of the table, treated as a single identifier
                (not a schema-qualified ``schema.table`` expression).

        Returns:
            Number of rows in the table.

        Raises:
            sqlite3.OperationalError: If the table does not exist.
        """
        # Identifiers cannot be bound as SQL parameters, so quote the name
        # instead: wrap it in double quotes and double any embedded quote.
        # This keeps names with spaces/keywords working and stops the value
        # from being interpreted as SQL.
        quoted_name = '"' + table_name.replace('"', '""') + '"'
        with self.get_connection() as conn:
            row = conn.execute(f"SELECT COUNT(*) FROM {quoted_name}").fetchone()
            return row[0] if row else 0
# Default instance for application-wide use
# Both objects are created when this module is imported. With no arguments the
# config resolves to the relative path ./data/pathways.db, so the file that is
# opened depends on the process working directory.
default_db_config = DatabaseConfig()
default_db_manager = DatabaseManager(default_db_config)
+581
View File
@@ -0,0 +1,581 @@
"""
Diagnosis lookup module for NHS Patient Pathway Analysis.
Provides functions to validate patient indications by checking GP diagnosis records
against SNOMED cluster codes. Uses the drug-to-cluster mapping from
drug_indication_clusters.csv and queries Snowflake for SNOMED codes and GP records.
Key workflow:
1. Get drug's valid indication clusters from local mapping
2. Get all SNOMED codes for those clusters from Snowflake
3. Check if patient has any of those SNOMED codes in GP records
4. Report indication validation status
IMPORTANT: HCD activity data indication codes are UNRELIABLE. This module uses
GP/Primary Care data (PrimaryCareClinicalCoding) as the authoritative source.
"""
from dataclasses import dataclass, field
from datetime import date, datetime
from pathlib import Path
from typing import Optional, Callable, Any, cast
import csv
from core.logging_config import get_logger
from data_processing.database import DatabaseManager, default_db_manager
from data_processing.snowflake_connector import (
SnowflakeConnector,
get_connector,
is_snowflake_available,
is_snowflake_configured,
SNOWFLAKE_AVAILABLE,
)
from data_processing.cache import get_cache, is_cache_enabled
logger = get_logger(__name__)
@dataclass
class ClusterSnomedCodes:
    """SNOMED codes for a clinical coding cluster."""
    # Cluster identifier, e.g. 'RARTH_COD' or 'PSORIASIS_COD'.
    cluster_id: str
    # Human-readable cluster description; empty string when nothing was found.
    cluster_description: str
    # SNOMED codes belonging to the cluster.
    snomed_codes: list[str] = field(default_factory=list)
    # Maps each SNOMED code to its description.
    snomed_descriptions: dict[str, str] = field(default_factory=dict)

    @property
    def code_count(self) -> int:
        """Number of entries in ``snomed_codes``."""
        return len(self.snomed_codes)
@dataclass
class IndicationValidationResult:
    """Result of validating a patient's indication for a drug."""
    # Pseudonymised patient identifier the check was run for.
    patient_pseudonym: str
    # Drug whose indication was validated.
    drug_name: str
    # True when a matching SNOMED code was found in the patient's GP records.
    has_valid_indication: bool
    # Details of the match; all three stay None when nothing matched.
    matched_cluster_id: Optional[str] = None
    matched_snomed_code: Optional[str] = None
    matched_snomed_description: Optional[str] = None
    # Cluster IDs that were looked up for the drug.
    checked_clusters: list[str] = field(default_factory=list)
    # NOTE(review): not assigned by the validation code visible alongside this
    # class, so it appears to stay at 0 - confirm whether it is populated elsewhere.
    total_codes_checked: int = 0
    source: str = "GP_SNOMED"  # GP_SNOMED | NONE
    # Reason validation could not be carried out (e.g. no cluster mappings).
    error_message: Optional[str] = None
@dataclass
class DrugIndicationMatchRate:
    """Match rate statistics for a drug's indication validation."""
    # Drug the statistics refer to.
    drug_name: str
    # Number of patients checked.
    total_patients: int
    # Patients with / without a matching indication.
    patients_with_indication: int
    patients_without_indication: int
    # patients_with_indication / total_patients (0.0 when there are no patients).
    match_rate: float  # 0.0 to 1.0
    # Cluster IDs the patients were checked against.
    clusters_checked: list[str] = field(default_factory=list)
    sample_unmatched: list[str] = field(default_factory=list)  # Sample patient IDs
def get_drug_clusters(
    drug_name: str,
    db_manager: Optional[DatabaseManager] = None
) -> list[dict]:
    """
    Look up a drug's SNOMED cluster mappings in the local SQLite database.

    Reads ``ref_drug_indication_clusters``; rows come back ordered by
    indication and then cluster ID.

    Args:
        drug_name: Drug name to look up (case-insensitive)
        db_manager: Optional DatabaseManager (defaults to default_db_manager)

    Returns:
        List of dicts with keys: drug_name, indication, cluster_id,
        cluster_description, nice_ta_reference. Empty list when the drug has
        no mappings or the query fails (the error is logged).
    """
    manager = default_db_manager if db_manager is None else db_manager

    query = """
        SELECT drug_name, indication, cluster_id, cluster_description, nice_ta_reference
        FROM ref_drug_indication_clusters
        WHERE UPPER(drug_name) = UPPER(?)
        ORDER BY indication, cluster_id
    """
    columns = (
        "drug_name",
        "indication",
        "cluster_id",
        "cluster_description",
        "nice_ta_reference",
    )

    try:
        with manager.get_connection() as conn:
            fetched = conn.execute(query, (drug_name,)).fetchall()
            # Copy each sqlite3.Row into a plain dict.
            mappings = [{name: row[name] for name in columns} for row in fetched]
            logger.debug(f"Found {len(mappings)} cluster mappings for drug '{drug_name}'")
            return mappings
    except Exception as exc:
        logger.error(f"Error getting clusters for drug '{drug_name}': {exc}")
        return []
def get_drug_cluster_ids(
    drug_name: str,
    db_manager: Optional[DatabaseManager] = None
) -> list[str]:
    """
    Get unique cluster IDs for a drug.

    The IDs keep the order in which get_drug_clusters() returns them, so the
    result is the same on every run.

    Args:
        drug_name: Drug name to look up
        db_manager: Optional DatabaseManager

    Returns:
        List of unique cluster IDs (empty when the drug has no mappings)
    """
    clusters = get_drug_clusters(drug_name, db_manager)
    # dict.fromkeys de-duplicates while preserving first-seen order; a set
    # would give an order that changes between interpreter runs.
    return list(dict.fromkeys(c["cluster_id"] for c in clusters))
def get_cluster_snomed_codes(
    cluster_id: str,
    connector: Optional[SnowflakeConnector] = None,
    use_cache: bool = True,
) -> ClusterSnomedCodes:
    """
    Get all SNOMED codes for a cluster from Snowflake.

    Queries the ClinicalCodingClusterSnomedCodes table to get all SNOMED codes
    that belong to the specified cluster. Cache access is best-effort: a
    failing cache read falls through to a fresh query, and a failing cache
    write does not discard the codes that were just fetched.

    Args:
        cluster_id: Cluster ID to look up (e.g., 'RARTH_COD', 'PSORIASIS_COD')
        connector: Optional SnowflakeConnector (defaults to singleton)
        use_cache: Whether to use cached results (default True)

    Returns:
        ClusterSnomedCodes with list of SNOMED codes and descriptions. An
        empty ClusterSnomedCodes (no codes, empty description) is returned
        when Snowflake is unavailable or unconfigured, the cluster has no
        codes, or the query fails.
    """
    if not SNOWFLAKE_AVAILABLE:
        logger.warning("Snowflake connector not available")
        return ClusterSnomedCodes(cluster_id=cluster_id, cluster_description="")

    if not is_snowflake_configured():
        logger.warning("Snowflake not configured - cannot get cluster codes")
        return ClusterSnomedCodes(cluster_id=cluster_id, cluster_description="")

    # Check cache first
    cache_key = f"cluster_snomed_{cluster_id}"
    if use_cache and is_cache_enabled():
        try:
            cached = get_cache().get(cache_key)
            if cached is not None and len(cached) > 0:
                cached_dict = cached[0]  # First element is our data dict
                cached_result = ClusterSnomedCodes(
                    cluster_id=cluster_id,
                    cluster_description=str(cached_dict.get("description", "")),
                    snomed_codes=list(cached_dict.get("codes", [])),
                    snomed_descriptions=dict(cached_dict.get("descriptions", {})),
                )
                logger.debug(f"Using cached SNOMED codes for cluster '{cluster_id}'")
                return cached_result
        except Exception as e:
            # An unreadable or malformed entry is treated as a cache miss.
            logger.warning(f"Cache lookup failed for cluster '{cluster_id}': {e}")

    if connector is None:
        connector = get_connector()

    query = '''
        SELECT DISTINCT
            "Cluster_ID",
            "Cluster_Description",
            "SNOMEDCode",
            "SNOMEDDescription"
        FROM DATA_HUB.PHM."ClinicalCodingClusterSnomedCodes"
        WHERE "Cluster_ID" = %s
        ORDER BY "SNOMEDCode"
    '''

    try:
        results = connector.execute_dict(query, (cluster_id,))

        if not results:
            logger.warning(f"No SNOMED codes found for cluster '{cluster_id}'")
            return ClusterSnomedCodes(cluster_id=cluster_id, cluster_description="")

        codes = []
        descriptions = {}
        # Every row carries the cluster description; take it from the first.
        description = results[0].get("Cluster_Description", "")

        for row in results:
            code = row.get("SNOMEDCode")
            if code:
                codes.append(code)
                descriptions[code] = row.get("SNOMEDDescription", "")

        logger.info(f"Found {len(codes)} SNOMED codes for cluster '{cluster_id}'")

        # Cache the results (using query-based cache with fake params)
        if use_cache and is_cache_enabled():
            cache_data = [{
                "description": description,
                "codes": codes,
                "descriptions": descriptions,
            }]
            try:
                get_cache().set(cache_key, None, cache_data)  # type: ignore[arg-type]
            except Exception as e:
                # Keep the freshly fetched codes even if they can't be cached.
                logger.warning(
                    f"Failed to cache SNOMED codes for cluster '{cluster_id}': {e}"
                )

        return ClusterSnomedCodes(
            cluster_id=cluster_id,
            cluster_description=description,
            snomed_codes=codes,
            snomed_descriptions=descriptions,
        )

    except Exception as e:
        logger.error(f"Error getting SNOMED codes for cluster '{cluster_id}': {e}")
        return ClusterSnomedCodes(cluster_id=cluster_id, cluster_description="")
def patient_has_indication(
    patient_pseudonym: str,
    cluster_ids: list[str],
    connector: Optional[SnowflakeConnector] = None,
    before_date: Optional[date] = None,
) -> tuple[bool, Optional[str], Optional[str], Optional[str]]:
    """
    Check GP records for any SNOMED code belonging to the given clusters.

    Runs a single Snowflake query joining the patient's primary care coding
    against the cluster code table and stops at the first matching row.

    Args:
        patient_pseudonym: Patient's pseudonymised NHS number
        cluster_ids: List of cluster IDs to check against
        connector: Optional SnowflakeConnector
        before_date: Optional date - only check diagnoses before this date

    Returns:
        Tuple of (has_indication, matched_cluster_id, matched_snomed_code,
        matched_description). (False, None, None, None) is returned when there
        is no match, no cluster IDs were given, Snowflake is unavailable or
        unconfigured, or the query fails (the error is logged).
    """
    no_match = (False, None, None, None)

    if not (SNOWFLAKE_AVAILABLE and is_snowflake_configured()):
        return no_match
    if not cluster_ids:
        return no_match

    if connector is None:
        connector = get_connector()

    # One bound parameter per cluster ID for the IN (...) clause.
    id_slots = ", ".join(["%s"] * len(cluster_ids))

    sql_parts = [f'''
        SELECT
            pc."SNOMEDCode",
            cc."Cluster_ID",
            cc."SNOMEDDescription"
        FROM DATA_HUB.PHM."PrimaryCareClinicalCoding" pc
        INNER JOIN DATA_HUB.PHM."ClinicalCodingClusterSnomedCodes" cc
            ON pc."SNOMEDCode" = cc."SNOMEDCode"
        WHERE pc."PatientPseudonym" = %s
            AND cc."Cluster_ID" IN ({id_slots})
    ''']
    bind_values = [patient_pseudonym] + cluster_ids

    if before_date:
        sql_parts.append(' AND pc."EventDateTime" < %s')
        bind_values.append(before_date.isoformat())

    # Existence check only: any single matching row is enough.
    sql_parts.append(' LIMIT 1')
    query = "".join(sql_parts)

    try:
        matches = connector.execute_dict(query, tuple(bind_values))
    except Exception as exc:
        logger.error(f"Error checking indication for patient '{patient_pseudonym}': {exc}")
        return no_match

    if not matches:
        return no_match

    first = matches[0]
    return (
        True,
        first.get("Cluster_ID"),
        first.get("SNOMEDCode"),
        first.get("SNOMEDDescription"),
    )
def validate_indication(
    patient_pseudonym: str,
    drug_name: str,
    connector: Optional[SnowflakeConnector] = None,
    db_manager: Optional[DatabaseManager] = None,
    before_date: Optional[date] = None,
) -> IndicationValidationResult:
    """
    Validate that a patient has an appropriate indication for a drug.

    Workflow:
    1. Get the drug's indication clusters from the local mapping
    2. Confirm Snowflake is installed and configured
    3. Check the patient's GP records for a matching SNOMED code

    Args:
        patient_pseudonym: Patient's pseudonymised NHS number
        drug_name: Drug name to validate indication for
        connector: Optional SnowflakeConnector
        db_manager: Optional DatabaseManager
        before_date: Optional date - only check diagnoses before this date

    Returns:
        IndicationValidationResult with validation details. When a
        precondition fails (no cluster mappings, Snowflake missing or
        unconfigured) the result has source "NONE" and error_message set.
    """
    outcome = IndicationValidationResult(
        patient_pseudonym=patient_pseudonym,
        drug_name=drug_name,
        has_valid_indication=False,
    )

    def _abort(reason: str) -> IndicationValidationResult:
        # Record why validation could not run and hand the result back.
        outcome.error_message = reason
        outcome.source = "NONE"
        return outcome

    # Step 1: cluster mappings for the drug
    cluster_ids = get_drug_cluster_ids(drug_name, db_manager)
    if not cluster_ids:
        return _abort(f"No cluster mappings found for drug '{drug_name}'")
    outcome.checked_clusters = cluster_ids

    # Step 2: Snowflake availability
    if not SNOWFLAKE_AVAILABLE:
        return _abort("Snowflake connector not installed")
    if not is_snowflake_configured():
        return _abort("Snowflake not configured")

    # Step 3: patient GP records
    found, cluster, code, description = patient_has_indication(
        patient_pseudonym=patient_pseudonym,
        cluster_ids=cluster_ids,
        connector=connector,
        before_date=before_date,
    )

    outcome.has_valid_indication = found
    outcome.matched_cluster_id = cluster
    outcome.matched_snomed_code = code
    outcome.matched_snomed_description = description
    if found:
        outcome.source = "GP_SNOMED"
    else:
        outcome.source = "NONE"
    return outcome
def get_indication_match_rate(
    drug_name: str,
    patient_pseudonyms: list[str],
    connector: Optional[SnowflakeConnector] = None,
    db_manager: Optional[DatabaseManager] = None,
    sample_unmatched_count: int = 10,
) -> DrugIndicationMatchRate:
    """
    Work out what share of a drug's patients have a matching GP indication.

    Each patient is checked individually with patient_has_indication(), so
    the number of Snowflake queries grows with the number of patients.
    Progress is logged every 100 patients.

    Args:
        drug_name: Drug name to check
        patient_pseudonyms: List of patient pseudonymised NHS numbers
        connector: Optional SnowflakeConnector
        db_manager: Optional DatabaseManager
        sample_unmatched_count: Number of unmatched patient IDs to include in sample

    Returns:
        DrugIndicationMatchRate with match statistics. If the drug has no
        cluster mappings every patient is reported as unmatched.
    """
    # Resolve the connector once so every per-patient check reuses it.
    if connector is None and SNOWFLAKE_AVAILABLE and is_snowflake_configured():
        connector = get_connector()

    cluster_ids = get_drug_cluster_ids(drug_name, db_manager)
    total = len(patient_pseudonyms)

    if not cluster_ids:
        logger.warning(f"No cluster mappings for drug '{drug_name}' - all patients will be unmatched")
        return DrugIndicationMatchRate(
            drug_name=drug_name,
            total_patients=total,
            patients_with_indication=0,
            patients_without_indication=total,
            match_rate=0.0,
            clusters_checked=[],
            sample_unmatched=patient_pseudonyms[:sample_unmatched_count],
        )

    hits = 0
    misses = 0
    missed_sample: list[str] = []

    for position, pseudonym in enumerate(patient_pseudonyms):
        if position > 0 and position % 100 == 0:
            logger.info(f"Validating indications: {position}/{total} ({100*position/total:.1f}%)")

        found = patient_has_indication(
            patient_pseudonym=pseudonym,
            cluster_ids=cluster_ids,
            connector=connector,
        )[0]

        if found:
            hits += 1
            continue

        misses += 1
        # Keep only the first few unmatched IDs as a sample.
        if len(missed_sample) < sample_unmatched_count:
            missed_sample.append(pseudonym)

    if total > 0:
        match_rate = hits / total
    else:
        match_rate = 0.0

    logger.info(f"Indication match rate for '{drug_name}': {100*match_rate:.1f}% ({hits}/{total})")

    return DrugIndicationMatchRate(
        drug_name=drug_name,
        total_patients=total,
        patients_with_indication=hits,
        patients_without_indication=misses,
        match_rate=match_rate,
        clusters_checked=cluster_ids,
        sample_unmatched=missed_sample,
    )
def batch_validate_indications(
    patient_drug_pairs: list[tuple[str, str]],
    connector: Optional[SnowflakeConnector] = None,
    db_manager: Optional[DatabaseManager] = None,
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> list[IndicationValidationResult]:
    """
    Validate indications for multiple patient-drug pairs efficiently.

    Cluster lookups are cached per drug (keyed by the upper-cased drug name)
    so each distinct drug is resolved only once for the whole batch.

    Args:
        patient_drug_pairs: List of (patient_pseudonym, drug_name) tuples
        connector: Optional SnowflakeConnector
        db_manager: Optional DatabaseManager
        progress_callback: Optional callback(current, total) for progress updates;
            called with a 1-based position before each pair is checked

    Returns:
        List of IndicationValidationResult for each pair, in input order.
        An empty input yields an empty list.
    """
    results = []
    total = len(patient_drug_pairs)

    # Cache cluster lookups by drug
    drug_clusters_cache = {}

    for i, (pseudonym, drug_name) in enumerate(patient_drug_pairs):
        if progress_callback:
            progress_callback(i + 1, total)

        # Get clusters from cache or lookup
        drug_upper = drug_name.upper()
        if drug_upper not in drug_clusters_cache:
            drug_clusters_cache[drug_upper] = get_drug_cluster_ids(drug_name, db_manager)
        cluster_ids = drug_clusters_cache[drug_upper]

        # A drug with no cluster mappings can never match; record why and move on
        if not cluster_ids:
            results.append(IndicationValidationResult(
                patient_pseudonym=pseudonym,
                drug_name=drug_name,
                has_valid_indication=False,
                source="NONE",
                error_message=f"No cluster mappings for drug '{drug_name}'",
            ))
            continue

        # Check patient indication
        has_indication, matched_cluster, matched_code, matched_desc = patient_has_indication(
            patient_pseudonym=pseudonym,
            cluster_ids=cluster_ids,
            connector=connector,
        )
        results.append(IndicationValidationResult(
            patient_pseudonym=pseudonym,
            drug_name=drug_name,
            has_valid_indication=has_indication,
            matched_cluster_id=matched_cluster,
            matched_snomed_code=matched_code,
            matched_snomed_description=matched_desc,
            checked_clusters=cluster_ids,
            source="GP_SNOMED" if has_indication else "NONE",
        ))

    matched_count = sum(1 for r in results if r.has_valid_indication)
    # Guard the percentage: an empty batch would otherwise divide by zero
    matched_pct = 100 * matched_count / total if total else 0.0
    logger.info(f"Batch validation complete: {matched_count}/{total} ({matched_pct:.1f}%) with valid indications")
    return results
def get_available_clusters(
    connector: Optional[SnowflakeConnector] = None,
) -> list[dict]:
    """
    List every SNOMED cluster known to Snowflake, with its code count.

    Args:
        connector: Optional SnowflakeConnector; obtained via get_connector()
            when not supplied

    Returns:
        List of dicts with cluster_id, cluster_description, code_count.
        Empty when Snowflake is unavailable/unconfigured or the query fails.
    """
    snowflake_ready = SNOWFLAKE_AVAILABLE and is_snowflake_configured()
    if not snowflake_ready:
        logger.warning("Snowflake not available - cannot list clusters")
        return []

    active_connector = get_connector() if connector is None else connector

    query = '''
SELECT
"Cluster_ID",
"Cluster_Description",
COUNT(DISTINCT "SNOMEDCode") as code_count
FROM DATA_HUB.PHM."ClinicalCodingClusterSnomedCodes"
GROUP BY "Cluster_ID", "Cluster_Description"
ORDER BY "Cluster_ID"
'''

    try:
        rows = active_connector.execute_dict(query)
        # Re-key each row to the snake_case names used by callers
        clusters = [
            {
                "cluster_id": row.get("Cluster_ID"),
                "cluster_description": row.get("Cluster_Description"),
                "code_count": row.get("code_count", 0),
            }
            for row in rows
        ]
        logger.info(f"Found {len(clusters)} available SNOMED clusters")
        return clusters
    except Exception as e:
        # Best-effort listing: report the failure and fall back to "no clusters"
        logger.error(f"Error getting available clusters: {e}")
        return []
# Export public API
# Names listed here are what `from <this module> import *` exposes;
# anything not listed is treated as internal to the module.
__all__ = [
    "ClusterSnomedCodes",
    "IndicationValidationResult",
    "DrugIndicationMatchRate",
    "get_drug_clusters",
    "get_drug_cluster_ids",
    "get_cluster_snomed_codes",
    "patient_has_indication",
    "validate_indication",
    "get_indication_match_rate",
    "batch_validate_indications",
    "get_available_clusters",
]
+399
View File
@@ -0,0 +1,399 @@
"""
Data loader abstractions for NHS High-Cost Drug Patient Pathway Analysis Tool.
Provides a unified interface for loading patient intervention data from:
- CSV/Parquet files (current behavior)
- SQLite database (new, faster approach)
- Snowflake (future, direct from warehouse)
The DataLoader ABC defines the contract for all loader implementations.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import date
from pathlib import Path
from typing import Optional
import pandas as pd
from core import PathConfig, default_paths
from core.logging_config import get_logger
logger = get_logger(__name__)
@dataclass
class LoadResult:
    """Outcome of a single data load.

    Attributes:
        df: The loaded DataFrame with processed patient intervention data
        source: Description of the data source (e.g., "csv:/path/to/file.csv", "sqlite:fact_interventions")
        row_count: Number of rows loaded
        columns: Column names of the DataFrame; filled in from ``df`` when left empty
        load_time_seconds: Time taken to load the data
    """

    df: pd.DataFrame
    source: str
    row_count: int
    columns: list[str] = field(default_factory=list)
    load_time_seconds: float = 0.0

    def __post_init__(self):
        # An explicitly supplied column list wins; otherwise mirror the DataFrame
        if self.columns:
            return
        self.columns = [*self.df.columns]
# Expected columns in a processed DataFrame
# These are the columns that generate_graph() expects to receive.
# DataLoader.validate_dataframe() reports any of these that are absent.
REQUIRED_COLUMNS = [
    "UPID",               # Unique Patient ID (Provider Code prefix + PersonKey)
    "Drug Name",          # Standardized drug name
    "Intervention Date",  # Date of intervention
    "Price Actual",       # Cost of intervention
    "OrganisationName",   # NHS Trust name
    "Directory",          # Medical specialty/directory
    "Provider Code",      # NHS provider code
    "PersonKey",          # Patient identifier within provider
]

# Additional columns that are useful but not strictly required.
# They are not checked by validate_dataframe().
OPTIONAL_COLUMNS = [
    "UPIDTreatment",            # UPID + Drug Name combo (created by generate_graph)
    "Treatment Function Code",  # NHS treatment function code
    "Additional Detail 1",
    "Additional Detail 2",
    "Additional Detail 3",
    "Additional Detail 4",
    "Additional Detail 5",
]
class DataLoader(ABC):
    """Contract shared by every patient-intervention data loader.

    Concrete loaders supply load(), validate_source() and the
    source_description property. The DataFrame produced by load() is meant
    for generate_graph() and must contain REQUIRED_COLUMNS at minimum.
    """

    @abstractmethod
    def load(self) -> LoadResult:
        """Load and process patient intervention data.

        Returns:
            LoadResult containing the processed DataFrame and metadata.
            The DataFrame must contain all REQUIRED_COLUMNS.

        Raises:
            FileNotFoundError: If the data source doesn't exist
            ValueError: If the data is malformed or missing required columns
        """
        ...

    @abstractmethod
    def validate_source(self) -> tuple[bool, str]:
        """Report whether the data source is valid and accessible.

        Returns:
            Tuple of (is_valid, message).
            When is_valid is False, message explains the issue.
        """
        ...

    @property
    @abstractmethod
    def source_description(self) -> str:
        """Human-readable description of the data source."""
        ...

    def validate_dataframe(self, df: pd.DataFrame) -> tuple[bool, list[str]]:
        """Check a DataFrame against REQUIRED_COLUMNS.

        Args:
            df: DataFrame to validate

        Returns:
            Tuple of (is_valid, missing_columns).
            missing_columns is empty when is_valid is True.
        """
        absent = [name for name in REQUIRED_COLUMNS if name not in df.columns]
        return not absent, absent
class FileDataLoader(DataLoader):
    """Loader for CSV or Parquet files.

    Mirrors the processing done by dashboard_gui.main():
        1. Read CSV or Parquet file
        2. Apply patient_id() transformation
        3. Convert dates
        4. Apply drug_names() standardization
        5. Clean organization names
        6. Apply department_identification()

    Args:
        file_path: Path to the CSV or Parquet file
        paths: PathConfig for reference data file locations (uses default_paths if None)
    """

    def __init__(
        self,
        file_path: Path | str,
        paths: Optional[PathConfig] = None,
    ):
        self.file_path = Path(file_path)
        self.paths = paths or default_paths

    def validate_source(self) -> tuple[bool, str]:
        """Confirm the file exists and carries a supported extension."""
        if not self.file_path.exists():
            return False, f"File not found: {self.file_path}"

        suffix = self.file_path.suffix.lower()
        if suffix in ('.csv', '.parquet'):
            return True, "OK"
        return False, f"Unsupported file type: {suffix}. Must be .csv or .parquet"

    @property
    def source_description(self) -> str:
        return f"file:{self.file_path}"

    def load(self) -> LoadResult:
        """Read the file and run it through the processing pipeline.

        The steps and their order follow the original
        dashboard_gui.main() function.

        Raises:
            FileNotFoundError: If validate_source() rejects the file
                (missing file or unsupported extension)
            ValueError: If the processed DataFrame lacks required columns
        """
        import time
        from tools import data

        started = time.time()

        # Refuse to go further if the source itself is unusable
        source_ok, reason = self.validate_source()
        if not source_ok:
            raise FileNotFoundError(reason)

        # Pick the reader from the (lower-cased) extension
        suffix = self.file_path.suffix.lower()
        logger.info(f"Reading {suffix} file: {self.file_path}")
        if suffix != '.csv':  # .parquet
            raw_frame = pd.read_parquet(self.file_path)
        else:
            raw_frame = pd.read_csv(self.file_path, low_memory=False)
        logger.info(f"File read successfully. {len(raw_frame)} rows.")

        # Transformations, in the same order as dashboard_gui.main()
        df = data.patient_id(raw_frame)
        logger.info("Patient ID processing complete.")

        df['Intervention Date'] = pd.to_datetime(df['Intervention Date'], format="%Y-%m-%d")
        logger.info("Date conversion complete.")

        # Keep the untouched drug name alongside the standardized one (for SQLite storage)
        df['Drug Name Raw'] = df['Drug Name'].copy()
        df = data.drug_names(df, self.paths)
        logger.info("Drug name processing complete.")

        df['OrganisationName'] = df['OrganisationName'].str.replace(',', '')
        logger.info("Organisation name cleaning complete.")

        df = data.department_identification(df, self.paths)
        logger.info("Department identification complete.")

        # The pipeline output must still satisfy the loader contract
        frame_ok, missing = self.validate_dataframe(df)
        if not frame_ok:
            raise ValueError(f"Processed DataFrame missing required columns: {missing}")

        load_time = time.time() - started
        logger.info(f"Data loading complete. {len(df)} rows in {load_time:.2f}s")
        return LoadResult(
            df=df,
            source=self.source_description,
            row_count=len(df),
            load_time_seconds=load_time,
        )
class SQLiteDataLoader(DataLoader):
    """Loader that reads the SQLite fact_interventions table.

    Reading pre-processed rows from SQLite avoids re-processing CSV files
    on every load. The database must have been populated by the migration
    scripts.

    Args:
        db_path: Path to the SQLite database (uses default if None)
        date_range: Optional tuple of (start_date, end_date) to filter data;
            start is inclusive, end is exclusive
        trusts: Optional list of trust names to filter
        drugs: Optional list of drug names to filter
        directories: Optional list of directories to filter
    """

    def __init__(
        self,
        db_path: Optional[Path | str] = None,
        date_range: Optional[tuple[date, date]] = None,
        trusts: Optional[list[str]] = None,
        drugs: Optional[list[str]] = None,
        directories: Optional[list[str]] = None,
    ):
        from data_processing.database import default_db_config

        if db_path:
            self.db_path = Path(db_path)
        else:
            self.db_path = Path(default_db_config.db_path)
        self.date_range = date_range
        self.trusts = trusts
        self.drugs = drugs
        self.directories = directories

    def validate_source(self) -> tuple[bool, str]:
        """Confirm the database file exists and fact_interventions has rows."""
        if not self.db_path.exists():
            return False, f"Database not found: {self.db_path}"

        # Look inside the database for the fact table
        from data_processing.database import DatabaseManager, DatabaseConfig

        manager = DatabaseManager(DatabaseConfig(db_path=self.db_path))
        if not manager.table_exists("fact_interventions"):
            return False, "fact_interventions table not found in database"

        available = manager.get_table_count("fact_interventions")
        if available == 0:
            return False, "fact_interventions table is empty"
        return True, f"OK ({available:,} rows available)"

    @property
    def source_description(self) -> str:
        return f"sqlite:{self.db_path}"

    def load(self) -> LoadResult:
        """Query fact_interventions into a DataFrame.

        SQLite column names are aliased to the DataFrame column names the
        rest of the tool expects. Any configured filters (date range, trusts,
        drugs, directories) are applied as bound parameters.

        Raises:
            FileNotFoundError: If validate_source() rejects the database
            ValueError: If the result lacks required columns
        """
        import time
        from data_processing.database import DatabaseManager, DatabaseConfig

        started = time.time()

        source_ok, reason = self.validate_source()
        if not source_ok:
            raise FileNotFoundError(reason)

        logger.info(f"Loading data from SQLite: {self.db_path}")

        # Base query; "WHERE 1=1" lets every filter be appended as " AND ..."
        query = """
SELECT
upid AS "UPID",
provider_code AS "Provider Code",
person_key AS "PersonKey",
drug_name_std AS "Drug Name",
intervention_date AS "Intervention Date",
price_actual AS "Price Actual",
org_name AS "OrganisationName",
directory AS "Directory",
treatment_function_code AS "Treatment Function Code",
additional_detail_1 AS "Additional Detail 1",
additional_detail_2 AS "Additional Detail 2",
additional_detail_3 AS "Additional Detail 3",
additional_detail_4 AS "Additional Detail 4",
additional_detail_5 AS "Additional Detail 5"
FROM fact_interventions
WHERE 1=1
"""
        params = []

        if self.date_range:
            start, end = self.date_range
            query += " AND intervention_date >= ? AND intervention_date < ?"
            params.extend([str(start), str(end)])

        # The three list filters share one shape: column IN (?, ?, ...)
        membership_filters = (
            ("org_name", self.trusts),
            ("drug_name_std", self.drugs),
            ("directory", self.directories),
        )
        for column, wanted in membership_filters:
            if not wanted:
                continue
            placeholders = ','.join('?' for _ in wanted)
            query += f" AND {column} IN ({placeholders})"
            params.extend(wanted)

        # Run the query
        manager = DatabaseManager(DatabaseConfig(db_path=self.db_path))
        with manager.get_connection() as conn:
            df = pd.read_sql_query(query, conn, params=params)

        # Convert intervention_date to datetime
        df['Intervention Date'] = pd.to_datetime(df['Intervention Date'])
        logger.info(f"Loaded {len(df)} rows from SQLite")

        frame_ok, missing = self.validate_dataframe(df)
        if not frame_ok:
            raise ValueError(f"SQLite data missing required columns: {missing}")

        load_time = time.time() - started
        logger.info(f"SQLite data loading complete. {len(df)} rows in {load_time:.2f}s")
        return LoadResult(
            df=df,
            source=self.source_description,
            row_count=len(df),
            load_time_seconds=load_time,
        )
def get_loader(
    source: str | Path,
    paths: Optional[PathConfig] = None,
    **kwargs
) -> DataLoader:
    """Build the DataLoader that matches ``source``.

    Args:
        source: Either a file path (CSV/Parquet) or "sqlite" for database
            (the "sqlite" keyword is matched case-insensitively)
        paths: PathConfig for reference data (used by FileDataLoader)
        **kwargs: Additional arguments passed to the SQLiteDataLoader
            constructor; not forwarded when a file loader is created

    Returns:
        Appropriate DataLoader instance

    Examples:
        >>> loader = get_loader("data/activity.csv")
        >>> loader = get_loader("data/activity.parquet")
        >>> loader = get_loader("sqlite")
        >>> loader = get_loader("sqlite", date_range=(date(2024, 1, 1), date(2024, 12, 31)))
    """
    wants_database = str(source).lower() == "sqlite"
    if not wants_database:
        # Anything else is treated as a file path
        return FileDataLoader(file_path=Path(source), paths=paths)
    return SQLiteDataLoader(**kwargs)
+593
View File
@@ -0,0 +1,593 @@
"""
Database migration script for NHS High-Cost Drug Patient Pathway Analysis Tool.
Provides functions to initialize the SQLite database schema and CLI interface
for running migrations from the command line.
Usage:
# Initialize database (creates all tables)
python -m data_processing.migrate
# Drop existing tables and reinitialize
python -m data_processing.migrate --drop-existing
# Show current database status
python -m data_processing.migrate --status
# Migrate all reference data from CSV files
python -m data_processing.migrate --reference-data
# Migrate reference data with verification
python -m data_processing.migrate --reference-data --verify
"""
import argparse
import sys
from pathlib import Path
from typing import Optional
from core.logging_config import setup_logging, get_logger
from data_processing.database import DatabaseManager, DatabaseConfig
from core import PathConfig, default_paths
from data_processing.schema import (
create_all_tables,
drop_all_tables,
verify_all_tables_exist,
get_all_table_counts,
)
from data_processing.reference_data import (
MigrationResult,
migrate_drug_names,
migrate_organizations,
migrate_directories,
migrate_drug_directory_map,
migrate_drug_indication_clusters,
verify_drug_names_migration,
verify_organizations_migration,
verify_directories_migration,
verify_drug_directory_map_migration,
verify_drug_indication_clusters_migration,
)
from data_processing.patient_data import (
load_patient_data,
refresh_patient_treatment_summary,
get_patient_data_stats,
verify_mv_consistency,
)
logger = get_logger(__name__)
def initialize_database(
    db_manager: Optional[DatabaseManager] = None,
    drop_existing: bool = False,
    confirm_drop: bool = True
) -> bool:
    """
    Create every table defined in the schema, optionally dropping first.

    Table creation goes through create_all_tables() inside a transaction and
    is followed by a check that no expected table is missing.

    Args:
        db_manager: DatabaseManager instance. Uses default if not provided.
        drop_existing: If True, drops all existing tables before creating.
        confirm_drop: If True and drop_existing=True, prompts for confirmation.
            Set to False for non-interactive use.

    Returns:
        True if initialization succeeded, False otherwise (including when the
        user declines the drop confirmation).
    """
    manager = DatabaseManager() if db_manager is None else db_manager

    logger.info(f"Initializing database at: {manager.db_path}")

    if drop_existing:
        # Interactive safety net before anything destructive happens
        if confirm_drop:
            print(f"\nWARNING: This will delete ALL data from the database:")
            print(f" {manager.db_path}\n")
            answer = input("Are you sure you want to continue? (yes/no): ")
            if answer.lower() not in ("yes", "y"):
                print("Operation cancelled.")
                return False

        if not manager.exists:
            logger.info("Database does not exist yet, nothing to drop")
        else:
            logger.warning("Dropping existing tables...")
            with manager.get_connection() as conn:
                drop_all_tables(conn)
                conn.commit()
            logger.info("Existing tables dropped")

    # Create all tables
    try:
        with manager.get_transaction() as conn:
            create_all_tables(conn)
    except Exception as e:
        logger.error(f"Failed to create tables: {e}")
        return False

    # Confirm nothing expected is absent
    with manager.get_connection() as conn:
        missing = verify_all_tables_exist(conn)

    if not missing:
        logger.info("All tables created successfully")
        return True

    logger.error(f"Table creation failed. Missing tables: {missing}")
    return False
def migrate_all_reference_data(
    db_manager: Optional[DatabaseManager] = None,
    paths: Optional[PathConfig] = None,
    verify: bool = False
) -> tuple[bool, list[MigrationResult]]:
    """
    Run every reference data migration, in a fixed order.

    Order of execution:
        1. Drug names
        2. Organizations
        3. Directories
        4. Drug-directory map
        5. Drug indication clusters (called without ``paths``)

    A failed migration is recorded and the remaining ones still run.

    Args:
        db_manager: DatabaseManager instance. Uses default if not provided.
        paths: PathConfig instance for locating CSV files. Uses default if not provided.
        verify: If True, runs verification after each successful migration.

    Returns:
        Tuple of (all_success: bool, results: list of MigrationResult)
    """
    if db_manager is None:
        db_manager = DatabaseManager()
    if paths is None:
        paths = default_paths

    # (label, migrate function, verify function, accepts a ``paths`` keyword)
    plan = [
        ("Drug names", migrate_drug_names, verify_drug_names_migration, True),
        ("Organizations", migrate_organizations, verify_organizations_migration, True),
        ("Directories", migrate_directories, verify_directories_migration, True),
        ("Drug-directory map", migrate_drug_directory_map, verify_drug_directory_map_migration, True),
        ("Drug indication clusters", migrate_drug_indication_clusters, verify_drug_indication_clusters_migration, False),
    ]

    results: list[MigrationResult] = []
    all_success = True

    logger.info(f"Starting reference data migrations ({len(plan)} tables)")

    for label, migrate_fn, verify_fn, takes_paths in plan:
        logger.info(f"Migrating: {label}...")

        # The same keyword set serves both the migrate and the verify call
        call_kwargs = {"db_manager": db_manager}
        if takes_paths:
            call_kwargs["paths"] = paths

        outcome = migrate_fn(**call_kwargs)
        results.append(outcome)

        if not outcome.success:
            logger.error(f"Migration failed: {label} - {outcome.error_message}")
            all_success = False
            continue

        logger.info(f" {outcome}")

        if not verify:
            continue

        logger.info(f" Verifying {label}...")
        verified, verify_msg = verify_fn(**call_kwargs)
        if verified:
            logger.info(f" OK: {verify_msg}")
        else:
            logger.error(f" FAILED: Verification failed: {verify_msg}")
            all_success = False

    # Summary
    succeeded = sum(1 for entry in results if entry.success)
    logger.info(f"Reference data migrations complete: {succeeded}/{len(results)} succeeded")
    return all_success, results
def print_migration_summary(results: list[MigrationResult]) -> None:
    """Write a per-table outcome report, plus a total, to stdout.

    Args:
        results: Migration outcomes to report, one block per entry.
    """
    print("\n=== Reference Data Migration Summary ===\n")

    for entry in results:
        if entry.success:
            print(f"[OK] {entry.table_name}")
            print(f" Read: {entry.rows_read}, Inserted: {entry.rows_inserted}, Skipped: {entry.rows_skipped}")
        else:
            print(f"[FAILED] {entry.table_name}")
            print(f" Error: {entry.error_message}")

    succeeded = len([entry for entry in results if entry.success])
    print(f"\nTotal: {succeeded}/{len(results)} migrations succeeded")
    print()
def create_progress_reporter(description: str = "Loading", width: int = 40):
    """
    Create a progress callback that prints a progress bar to stdout.

    The bar is redrawn in place (carriage return) and only when the integer
    percentage changes, so frequent callbacks do not flood the terminal.

    Args:
        description: Label to show before the progress bar.
        width: Width of the progress bar in characters.

    Returns:
        Callback function(current, total) that prints progress.
    """
    last_percent = -1  # percentage shown by the most recent redraw

    def report_progress(current: int, total: int) -> None:
        """Print a progress bar showing current/total progress."""
        nonlocal last_percent

        # A zero total means there is nothing to do: show the bar as complete
        if total == 0:
            percent = 100
        else:
            percent = int(100 * current / total)

        # Only update display when percentage changes (avoid excessive output)
        if percent == last_percent:
            return
        last_percent = percent

        filled = int(width * current / total) if total > 0 else width
        # Clamp so an overshoot (current > total) or a negative value can
        # never make the bar wider than ``width`` characters
        filled = max(0, min(width, filled))
        bar = "=" * filled + "-" * (width - filled)

        # Use carriage return to overwrite the line
        sys.stdout.write(f"\r{description}: [{bar}] {percent:3d}% ({current:,}/{total:,})")
        sys.stdout.flush()

        # Print newline when complete
        if current >= total:
            print()

    return report_progress
def load_patient_data_cli(
    file_path: Path,
    db_manager: Optional[DatabaseManager] = None,
    paths: Optional[PathConfig] = None,
    force: bool = False,
    refresh_mv: bool = True
) -> bool:
    """
    Load a patient data file and narrate the process on stdout.

    Prints file details, a progress bar while rows load, the load outcome,
    an optional materialized view refresh with a consistency check, and
    finally summary statistics for the patient data.

    Args:
        file_path: Path to CSV or Parquet file.
        db_manager: DatabaseManager instance. Uses default if not provided.
        paths: PathConfig for reference data. Uses default if not provided.
        force: If True, re-process even if file hash matches.
        refresh_mv: If True, refresh the materialized view after loading.

    Returns:
        False when the file is missing or the load fails; otherwise the
        ``success`` flag of the load result. A failed materialized view
        refresh is reported but does not change the return value.
    """
    db_manager = DatabaseManager() if db_manager is None else db_manager
    paths = default_paths if paths is None else paths

    print(f"\n=== Loading Patient Data ===\n")
    print(f"File: {file_path}")

    # Nothing to do without the input file
    if not file_path.exists():
        print(f"ERROR: File not found: {file_path}")
        return False

    size_mb = file_path.stat().st_size / (1024 * 1024)
    print(f"Size: {size_mb:.1f} MB")
    print()

    row_progress = create_progress_reporter("Loading rows", width=40)
    result = load_patient_data(
        file_path=file_path,
        db_manager=db_manager,
        paths=paths,
        batch_size=5000,
        force=force,
        progress_callback=row_progress
    )

    # Report the outcome of the load
    print()
    if not (result.was_already_processed or result.success):
        print(f"FAILED: {result.error_message}")
        return False

    if result.was_already_processed:
        print("File already processed (same hash). Skipping.")
        print(f"Use --force to re-process.")
    else:
        print(f"Loaded {result.rows_inserted:,} rows in {result.load_time_seconds:.1f}s")
        if result.rows_skipped > 0:
            print(f"Skipped {result.rows_skipped:,} rows (missing UPID or date)")

    # The materialized view only needs refreshing when new rows were loaded
    if refresh_mv and result.success and not result.was_already_processed:
        print()
        print("Refreshing materialized view...")

        patient_progress = create_progress_reporter("Processing patients", width=40)
        mv_result = refresh_patient_treatment_summary(
            db_manager=db_manager,
            progress_callback=patient_progress
        )

        if not mv_result.success:
            print(f"MV refresh FAILED: {mv_result.error_message}")
        else:
            print(f"MV refreshed: {mv_result.patients_processed:,} patients in {mv_result.refresh_time_seconds:.1f}s")

            # Verify consistency
            consistent, msg = verify_mv_consistency(db_manager)
            if not consistent:
                print(f"MV verification: FAILED - {msg}")
            else:
                print(f"MV verification: OK")

    # Summary statistics are printed for fresh and already-processed files alike
    print()
    print("=== Patient Data Summary ===")
    stats = get_patient_data_stats(db_manager)
    print(f" Total rows: {stats['total_rows']:,}")
    print(f" Unique patients: {stats['unique_patients']:,}")
    print(f" Unique drugs: {stats['unique_drugs']:,}")
    print(f" Unique organizations: {stats['unique_organizations']:,}")
    span = stats['date_range']
    if span[0] and span[1]:
        print(f" Date range: {span[0]} to {span[1]}")
    print()

    return result.success
def get_database_status(db_manager: Optional[DatabaseManager] = None) -> dict:
    """
    Collect a snapshot of the database's current state.

    Args:
        db_manager: DatabaseManager instance. Uses default if not provided.

    Returns:
        Dictionary with database status information:
        - exists: Whether the database file exists
        - path: Path to the database file
        - size_bytes: Size of database file (None if it does not exist)
        - tables: Dictionary of table names to row counts
        - missing_tables: List of expected tables that don't exist
    """
    if db_manager is None:
        db_manager = DatabaseManager()

    status = {
        "exists": db_manager.exists,
        "path": str(db_manager.db_path),
        "size_bytes": None,
        "tables": {},
        "missing_tables": [],
    }

    # Without a database file the defaults above are the whole answer
    if not db_manager.exists:
        return status

    status["size_bytes"] = db_manager.db_path.stat().st_size

    with db_manager.get_connection() as conn:
        status["missing_tables"] = verify_all_tables_exist(conn)

        # Row counts are best-effort; a failure leaves "tables" empty
        try:
            status["tables"] = get_all_table_counts(conn)
        except Exception as e:
            logger.warning(f"Could not get table counts: {e}")

    return status
def print_database_status(db_manager: Optional[DatabaseManager] = None) -> None:
    """Render the result of get_database_status() to stdout for a human reader.

    Args:
        db_manager: DatabaseManager instance, passed through to
            get_database_status().
    """
    status = get_database_status(db_manager)

    print("\n=== Database Status ===\n")
    print(f"Path: {status['path']}")
    print(f"Exists: {status['exists']}")

    if not status["exists"]:
        print("\nDatabase does not exist. Run migration to create it.")
        print()
        return

    size_kb = (status["size_bytes"] or 0) / 1024
    print(f"Size: {size_kb:.1f} KB")

    missing = status["missing_tables"]
    if missing:
        print(f"\nMissing tables: {', '.join(missing)}")
    else:
        print("\nAll expected tables exist.")

    counts = status["tables"]
    if counts:
        print("\nTable row counts:")
        for table, count in sorted(counts.items()):
            print(f" {table}: {count:,} rows")

    print()
def main():
    """CLI entry point for database migration.

    Parses the command line, builds a DatabaseManager (honouring --db-path),
    and then runs exactly one action. Actions are checked in this order and
    the first match wins:

        1. --status             print the database status
        2. --reference-data     migrate reference data (optionally --verify)
        3. --load-patient-data  load a patient data file
        4. (default)            initialize the schema (optionally --drop-existing)

    Returns:
        0 on success, 1 on failure; used as the process exit code.
    """
    parser = argparse.ArgumentParser(
        description="Initialize NHS Pathways Analysis SQLite database schema",
        # Raw formatter keeps the line breaks of the epilog below
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python -m data_processing.migrate # Initialize database
python -m data_processing.migrate --status # Show database status
python -m data_processing.migrate --drop-existing # Reset database
python -m data_processing.migrate --reference-data # Migrate reference data
python -m data_processing.migrate --reference-data --verify # With verification
python -m data_processing.migrate --load-patient-data data.parquet # Load patient data
python -m data_processing.migrate --load-patient-data data.csv --force # Force reload
python -m data_processing.migrate --db-path ./data/test.db # Custom path
"""
    )
    parser.add_argument(
        "--status",
        action="store_true",
        help="Show current database status and exit"
    )
    parser.add_argument(
        "--drop-existing",
        action="store_true",
        help="Drop all existing tables before creating (WARNING: deletes data)"
    )
    parser.add_argument(
        "--reference-data",
        action="store_true",
        help="Migrate all reference data from CSV files to SQLite tables"
    )
    parser.add_argument(
        "--verify",
        action="store_true",
        help="Verify migrated data matches CSV sources (use with --reference-data)"
    )
    parser.add_argument(
        "--db-path",
        type=Path,
        help="Path to database file (default: ./data/pathways.db)"
    )
    parser.add_argument(
        "--yes", "-y",
        action="store_true",
        help="Skip confirmation prompts (for non-interactive use)"
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose logging"
    )
    parser.add_argument(
        "--load-patient-data",
        type=Path,
        metavar="FILE",
        help="Load patient data from CSV or Parquet file with progress reporting"
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Force re-processing even if file hash matches (use with --load-patient-data)"
    )
    parser.add_argument(
        "--no-refresh-mv",
        action="store_true",
        help="Skip materialized view refresh after loading (use with --load-patient-data)"
    )
    args = parser.parse_args()

    # Set up logging
    log_level = "DEBUG" if args.verbose else "INFO"
    setup_logging(level=log_level, simple_console=True)

    # Create database manager with optional custom path
    if args.db_path:
        config = DatabaseConfig(db_path=args.db_path)
        db_manager = DatabaseManager(config)
    else:
        db_manager = DatabaseManager()

    # Handle --status
    # This runs before configuration validation, so status is always printable
    if args.status:
        print_database_status(db_manager)
        return 0

    # Validate configuration
    config_errors = db_manager.config.validate()
    if config_errors:
        for error in config_errors:
            logger.error(error)
        return 1

    # Handle --reference-data (migrate reference data from CSV to SQLite)
    if args.reference_data:
        # Ensure database exists with tables first
        # (only a missing database file triggers initialization here)
        if not db_manager.exists:
            print("Database does not exist. Initializing schema first...")
            success = initialize_database(db_manager=db_manager)
            if not success:
                print("\nDatabase initialization failed. Check logs for details.")
                return 1

        # Run reference data migrations
        success, results = migrate_all_reference_data(
            db_manager=db_manager,
            paths=default_paths,
            verify=args.verify
        )
        print_migration_summary(results)
        print_database_status(db_manager)

        if success:
            print("Reference data migration completed successfully.")
            return 0
        else:
            print("Reference data migration completed with errors. Check logs for details.")
            return 1

    # Handle --load-patient-data (load patient data from CSV/Parquet)
    if args.load_patient_data:
        # Ensure database exists with tables first
        if not db_manager.exists:
            print("Database does not exist. Initializing schema first...")
            success = initialize_database(db_manager=db_manager)
            if not success:
                print("\nDatabase initialization failed. Check logs for details.")
                return 1

        # Load patient data with progress reporting
        success = load_patient_data_cli(
            file_path=args.load_patient_data,
            db_manager=db_manager,
            paths=default_paths,
            force=args.force,
            # The flag is negative ("--no-refresh-mv"), the parameter positive
            refresh_mv=not args.no_refresh_mv
        )

        if success:
            print("Patient data load completed successfully.")
            return 0
        else:
            print("Patient data load failed. Check logs for details.")
            return 1

    # Run schema migration (default behavior)
    success = initialize_database(
        db_manager=db_manager,
        drop_existing=args.drop_existing,
        # --yes suppresses the interactive confirmation before dropping tables
        confirm_drop=not args.yes
    )

    if success:
        print("\nDatabase initialized successfully.")
        print_database_status(db_manager)
        return 0
    else:
        print("\nDatabase initialization failed. Check logs for details.")
        return 1
# Run the CLI only when executed as a script/module, and hand main()'s
# return value (0 or 1) to the shell as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
+890
View File
@@ -0,0 +1,890 @@
"""
Patient data migration functions for NHS High-Cost Drug Patient Pathway Analysis Tool.
Provides functions to load patient intervention data from CSV/Parquet files
into the SQLite fact_interventions table. Supports:
- Batch processing for large files
- File hash tracking for incremental updates
- Progress reporting during loading
"""
import hashlib
import os
import sqlite3
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Callable, Optional
import pandas as pd
from core import PathConfig, default_paths
from core.logging_config import get_logger
from data_processing.database import DatabaseManager
logger = get_logger(__name__)
@dataclass
class PatientDataLoadResult:
    """Outcome of a single attempt to load one patient data file.

    ``str()`` of an instance gives a one-line summary: "already processed",
    a success line with row count and duration, or a failure line with the
    error message.
    """

    file_path: str  # path of the file the load was attempted for
    file_hash: str  # hash of the file contents ("" when none was computed)
    rows_read: int
    rows_inserted: int
    rows_skipped: int
    success: bool
    error_message: Optional[str] = None
    load_time_seconds: float = 0.0
    was_already_processed: bool = False

    def __str__(self) -> str:
        # "Already processed" wins over the success/failure wording.
        if self.was_already_processed:
            return f"{self.file_path}: Already processed (same hash)"
        if not self.success:
            return f"{self.file_path}: FAILED - {self.error_message}"
        return (
            f"{self.file_path}: Loaded {self.rows_inserted:,} rows "
            f"in {self.load_time_seconds:.1f}s"
        )
def calculate_file_hash(file_path: Path) -> str:
    """
    Compute the SHA256 digest of a file's contents.

    The file is read in fixed-size binary chunks so it never has to be held
    in memory all at once.

    Args:
        file_path: Path to the file.

    Returns:
        The digest as a hexadecimal string.
    """
    digest = hashlib.sha256()
    with open(file_path, "rb") as stream:
        # read() returns b"" at end of file, which ends the loop.
        while chunk := stream.read(8192):
            digest.update(chunk)
    return digest.hexdigest()
def check_file_processed(
    conn: sqlite3.Connection,
    file_path: str,
    file_hash: str
) -> tuple[bool, Optional[str]]:
    """
    Look up whether a file was already loaded successfully with this hash.

    Rows are accessed by column name, so the connection must return
    name-addressable rows (e.g. ``sqlite3.Row``).

    Args:
        conn: Database connection.
        file_path: Full path to the file.
        file_hash: SHA256 hash of the file.

    Returns:
        Tuple of (is_processed, old_hash).
        - (False, None): no processed_files row exists for this path.
        - (True, old_hash): the recorded status is "success" and the recorded
          hash equals file_hash.
        - (False, old_hash): a row exists but the hash differs or the
          recorded status is not "success".
    """
    record = conn.execute(
        "SELECT file_hash, status FROM processed_files WHERE file_path = ?",
        (file_path,)
    ).fetchone()

    if record is None:
        return False, None

    previous_hash = record["file_hash"]
    # Both conditions are required: a failed or in-progress load must be retried
    # even when the file contents have not changed.
    is_unchanged_success = record["status"] == "success" and previous_hash == file_hash
    return is_unchanged_success, previous_hash
def record_file_processing_start(
    conn: sqlite3.Connection,
    file_path: str,
    file_hash: str,
    file_size: int,
    file_modified: datetime
) -> None:
    """
    Mark a file as 'processing' in the processed_files table.

    Inserts a new row for the path, or, when a row for the path already
    exists, refreshes its hash, size, modification time and last-processed
    timestamp, resets the status to 'processing' and clears any stored error
    message. The existing file_name and first_processed_at are left as they
    were on conflict.

    Args:
        conn: Database connection.
        file_path: Full path to the file.
        file_hash: SHA256 hash of the file.
        file_size: File size in bytes.
        file_modified: File modification timestamp.
    """
    started_at = datetime.now().isoformat()
    upsert_sql = """
        INSERT INTO processed_files (
            file_path, file_name, file_hash, file_size_bytes,
            file_modified_at, status, first_processed_at, last_processed_at
        ) VALUES (?, ?, ?, ?, ?, 'processing', ?, ?)
        ON CONFLICT(file_path) DO UPDATE SET
            file_hash = excluded.file_hash,
            file_size_bytes = excluded.file_size_bytes,
            file_modified_at = excluded.file_modified_at,
            status = 'processing',
            last_processed_at = excluded.last_processed_at,
            error_message = NULL
    """
    parameters = (
        file_path,
        Path(file_path).name,
        file_hash,
        file_size,
        file_modified.isoformat(),
        started_at,  # first_processed_at (only used on first insert)
        started_at,  # last_processed_at
    )
    conn.execute(upsert_sql, parameters)
def record_file_processing_complete(
    conn: sqlite3.Connection,
    file_path: str,
    row_count: int,
    duration_seconds: float,
    success: bool,
    error_message: Optional[str] = None
) -> None:
    """
    Store the final outcome of processing a file.

    Updates the existing processed_files row for the path; if no row exists
    for the path, nothing is written.

    Args:
        conn: Database connection.
        file_path: Full path to the file.
        row_count: Number of rows processed.
        duration_seconds: Time taken to process.
        success: Whether processing was successful.
        error_message: Error message if failed.
    """
    outcome = "error"
    if success:
        outcome = "success"

    finished_at = datetime.now().isoformat()
    update_sql = """
        UPDATE processed_files
        SET status = ?,
            row_count = ?,
            processing_duration_seconds = ?,
            error_message = ?,
            last_processed_at = ?
        WHERE file_path = ?
    """
    conn.execute(
        update_sql,
        (outcome, row_count, duration_seconds, error_message, finished_at, file_path),
    )
def load_dataframe_to_sqlite(
    df: pd.DataFrame,
    conn: sqlite3.Connection,
    source_file: str,
    batch_size: int = 5000,
    progress_callback: Optional[Callable[[int, int], None]] = None
) -> int:
    """
    Load a processed DataFrame into the fact_interventions table.

    Rows whose "UPID" or "Intervention Date" is missing are skipped. For the
    remaining rows:
    - a missing "Drug Name" falls back to the upper-cased, stripped
      "Drug Name Raw", or to "UNKNOWN" when that is missing too;
    - a missing "Price Actual" is stored as 0;
    - values that expose ``strftime`` (dates/timestamps) are stored as
      "YYYY-MM-DD" text;
    - any other missing value, or a column absent from the DataFrame, is
      stored as NULL.

    This function does not commit; transaction handling is left to the caller.

    Args:
        df: Processed DataFrame with required columns (from FileDataLoader).
        conn: Database connection.
        source_file: Source file path for tracking.
        batch_size: Number of rows to insert per batch. Must be positive.
        progress_callback: Optional callback(rows_inserted, total_rows) for
            progress updates, invoked once per batch. Skipped rows are not
            counted in rows_inserted.

    Returns:
        Number of rows inserted.

    Raises:
        ValueError: If batch_size is not a positive integer.
    """
    # A negative step would make range() empty and silently insert nothing.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    insert_columns = [
        "upid", "provider_code", "person_key",
        "drug_name_raw", "drug_name_std",
        "intervention_date", "price_actual",
        "org_name", "directory",
        "treatment_function_code",
        "additional_detail_1", "additional_detail_2", "additional_detail_3",
        "additional_detail_4", "additional_detail_5",
        "source_file"
    ]
    placeholders = ",".join(["?"] * len(insert_columns))
    insert_sql = f"""
        INSERT INTO fact_interventions ({",".join(insert_columns)})
        VALUES ({placeholders})
    """

    # Column presence cannot change while loading, so resolve it once here
    # instead of re-checking df.columns for every value of every row.
    value_columns = (
        "UPID", "Provider Code", "PersonKey", "Intervention Date", "Price Actual",
        "OrganisationName", "Directory", "Treatment Function Code",
        "Additional Detail 1", "Additional Detail 2", "Additional Detail 3",
        "Additional Detail 4", "Additional Detail 5",
    )
    present_columns = {name for name in value_columns if name in df.columns}
    has_raw_drug_name = "Drug Name Raw" in df.columns

    def get_value(row, col_name):
        """Return row[col_name] in a SQLite-bindable form (None if absent/missing)."""
        if col_name not in present_columns:
            return None
        val = row[col_name]
        if pd.isna(val):
            return None
        if hasattr(val, "strftime"):
            return val.strftime("%Y-%m-%d")
        return val

    rows_inserted = 0
    rows_skipped = 0
    total_rows = len(df)

    for batch_start in range(0, total_rows, batch_size):
        batch_df = df.iloc[batch_start:batch_start + batch_size]
        batch_data = []

        for _, row in batch_df.iterrows():
            # Skip rows missing required fields
            if pd.isna(row.get("UPID")) or pd.isna(row.get("Intervention Date")):
                rows_skipped += 1
                continue

            # Raw drug name: trimmed text, or None when absent or missing.
            raw_value = row.get("Drug Name Raw") if has_raw_drug_name else None
            raw_missing = raw_value is None or pd.isna(raw_value)
            drug_name_raw = None if raw_missing else str(raw_value).strip()

            # Standardized name: fall back to the raw name (upper-cased) for
            # drugs without a mapping, and to "UNKNOWN" when nothing is known.
            drug_name_std = row.get("Drug Name")
            if pd.isna(drug_name_std):
                if raw_missing:
                    drug_name_std = "UNKNOWN"
                else:
                    drug_name_std = str(raw_value).upper().strip()

            batch_data.append((
                get_value(row, "UPID"),
                get_value(row, "Provider Code"),
                get_value(row, "PersonKey"),
                drug_name_raw,
                drug_name_std,
                get_value(row, "Intervention Date"),
                get_value(row, "Price Actual") or 0,
                get_value(row, "OrganisationName"),
                get_value(row, "Directory"),
                get_value(row, "Treatment Function Code"),
                get_value(row, "Additional Detail 1"),
                get_value(row, "Additional Detail 2"),
                get_value(row, "Additional Detail 3"),
                get_value(row, "Additional Detail 4"),
                get_value(row, "Additional Detail 5"),
                source_file,
            ))

        conn.executemany(insert_sql, batch_data)
        rows_inserted += len(batch_data)

        if progress_callback:
            progress_callback(rows_inserted, total_rows)

    if rows_skipped > 0:
        logger.info(f"Skipped {rows_skipped:,} rows with missing UPID or Intervention Date")

    return rows_inserted
def delete_file_data(conn: sqlite3.Connection, source_file: str) -> int:
    """
    Remove every fact_interventions row that came from one source file.

    Called before re-loading a file so its rows are not duplicated.

    Args:
        conn: Database connection.
        source_file: Source file path, as stored in the source_file column.

    Returns:
        Number of rows deleted.
    """
    return conn.execute(
        "DELETE FROM fact_interventions WHERE source_file = ?",
        (source_file,)
    ).rowcount
def _record_load_failure(
    db_manager: DatabaseManager,
    file_path_str: str,
    file_hash: str,
    file_size: int,
    file_modified: datetime,
    load_time: float,
    error_msg: str
) -> None:
    """Best-effort: persist an 'error' row in processed_files for a failed load."""
    try:
        with db_manager.get_transaction() as conn:
            # The 'processing' row written by the failed load may never have been
            # committed (or the failure happened before it was written), so upsert
            # it first; otherwise the status UPDATE below would match no row.
            record_file_processing_start(
                conn, file_path_str, file_hash, file_size, file_modified
            )
            record_file_processing_complete(
                conn, file_path_str, 0, load_time, False, error_msg
            )
    except Exception as record_error:
        # Never let bookkeeping mask the original load error, but leave a trace.
        logger.warning(
            f"Could not record load failure for {file_path_str}: {record_error}"
        )


def load_patient_data(
    file_path: Path | str,
    db_manager: Optional[DatabaseManager] = None,
    paths: Optional[PathConfig] = None,
    batch_size: int = 5000,
    force: bool = False,
    progress_callback: Optional[Callable[[int, int], None]] = None
) -> PatientDataLoadResult:
    """
    Load patient data from CSV/Parquet file into fact_interventions table.

    This is the main entry point for loading patient data. It:
    1. Calculates file hash to detect changes
    2. Checks if file was already processed (skip if unchanged, unless force)
    3. Loads and transforms data using FileDataLoader
    4. Replaces any rows previously loaded from this file and inserts the new
       data in batches, inside a single transaction
    5. Records processing status in processed_files table

    Errors while loading are not raised: they are logged, recorded in
    processed_files on a best-effort basis, and reported through the returned
    result (success=False, error_message set).

    Args:
        file_path: Path to CSV or Parquet file.
        db_manager: DatabaseManager instance. Uses default if not provided.
        paths: PathConfig for reference data. Uses default if not provided.
        batch_size: Number of rows to insert per batch (default: 5000).
        force: If True, re-process even if file hash matches.
        progress_callback: Optional callback(rows_inserted, total_rows) for progress.

    Returns:
        PatientDataLoadResult with loading statistics.
    """
    if db_manager is None:
        db_manager = DatabaseManager()
    if paths is None:
        paths = default_paths

    file_path = Path(file_path)
    # The absolute path is the key used in processed_files and source_file.
    file_path_str = str(file_path.absolute())
    logger.info(f"Starting patient data load from {file_path}")
    start_time = time.time()

    # Check file exists
    if not file_path.exists():
        error_msg = f"File not found: {file_path}"
        logger.error(error_msg)
        return PatientDataLoadResult(
            file_path=file_path_str,
            file_hash="",
            rows_read=0,
            rows_inserted=0,
            rows_skipped=0,
            success=False,
            error_message=error_msg
        )

    # Calculate file hash and collect file metadata (one stat() call)
    logger.info("Calculating file hash...")
    file_hash = calculate_file_hash(file_path)
    file_stat = file_path.stat()
    file_size = file_stat.st_size
    file_modified = datetime.fromtimestamp(file_stat.st_mtime)
    logger.info(f"File hash: {file_hash[:16]}... Size: {file_size:,} bytes")

    # Check if already processed
    if not force:
        with db_manager.get_connection() as conn:
            is_processed, old_hash = check_file_processed(conn, file_path_str, file_hash)
        if is_processed:
            logger.info("File already processed with same hash, skipping")
            return PatientDataLoadResult(
                file_path=file_path_str,
                file_hash=file_hash,
                rows_read=0,
                rows_inserted=0,
                rows_skipped=0,
                success=True,
                was_already_processed=True
            )
        elif old_hash is not None:
            logger.info(f"File hash changed, will re-process (old: {old_hash[:16]}...)")

    try:
        # Use FileDataLoader to load and transform data.
        # Imported here rather than at module level, as in the original code.
        from data_processing.loader import FileDataLoader

        loader = FileDataLoader(file_path, paths)
        logger.info("Loading and transforming data...")
        result = loader.load()
        df = result.df
        rows_read = result.row_count
        logger.info(f"Loaded {rows_read:,} rows, starting SQLite insert...")

        # Load into SQLite
        with db_manager.get_transaction() as conn:
            # Record that we're starting
            record_file_processing_start(conn, file_path_str, file_hash, file_size, file_modified)

            # Delete any existing data from this file (for re-processing)
            deleted = delete_file_data(conn, file_path_str)
            if deleted > 0:
                logger.info(f"Deleted {deleted:,} existing rows from previous load")

            # Insert new data
            rows_inserted = load_dataframe_to_sqlite(
                df, conn, file_path_str, batch_size, progress_callback
            )

            # Record success
            load_time = time.time() - start_time
            record_file_processing_complete(
                conn, file_path_str, rows_inserted, load_time, True
            )

            logger.info(f"Successfully loaded {rows_inserted:,} rows in {load_time:.1f}s")
            return PatientDataLoadResult(
                file_path=file_path_str,
                file_hash=file_hash,
                rows_read=rows_read,
                rows_inserted=rows_inserted,
                rows_skipped=rows_read - rows_inserted,
                success=True,
                load_time_seconds=load_time
            )

    except Exception as e:
        load_time = time.time() - start_time
        error_msg = str(e)
        logger.error(f"Failed to load patient data: {error_msg}")

        # Record failure (best-effort; never raises)
        _record_load_failure(
            db_manager, file_path_str, file_hash, file_size, file_modified,
            load_time, error_msg
        )

        return PatientDataLoadResult(
            file_path=file_path_str,
            file_hash=file_hash,
            rows_read=0,
            rows_inserted=0,
            rows_skipped=0,
            success=False,
            error_message=error_msg,
            load_time_seconds=load_time
        )
def get_patient_data_stats(db_manager: Optional[DatabaseManager] = None) -> dict:
    """
    Summarise the contents of fact_interventions and processed_files.

    Args:
        db_manager: DatabaseManager instance. Uses default if not provided.

    Returns:
        Dictionary with the keys total_rows, unique_patients, unique_drugs,
        unique_organizations, date_range (min, max intervention date),
        processed_files and processed_rows (both counted over files whose
        status is 'success').
    """
    if db_manager is None:
        db_manager = DatabaseManager()

    # Each of these queries yields one number, stored under the given key.
    count_queries = (
        ("total_rows", "SELECT COUNT(*) FROM fact_interventions"),
        ("unique_patients", "SELECT COUNT(DISTINCT upid) FROM fact_interventions"),
        ("unique_drugs", "SELECT COUNT(DISTINCT drug_name_std) FROM fact_interventions"),
        ("unique_organizations", "SELECT COUNT(DISTINCT org_name) FROM fact_interventions"),
    )

    stats = {}
    with db_manager.get_connection() as conn:
        for key, query in count_queries:
            stats[key] = conn.execute(query).fetchone()[0]

        # Date range
        date_row = conn.execute("""
            SELECT MIN(intervention_date), MAX(intervention_date)
            FROM fact_interventions
        """).fetchone()
        if date_row:
            stats["date_range"] = (date_row[0], date_row[1])
        else:
            stats["date_range"] = (None, None)

        # Processed files
        files_row = conn.execute("""
            SELECT COUNT(*), SUM(row_count)
            FROM processed_files WHERE status = 'success'
        """).fetchone()
        stats["processed_files"] = files_row[0] if files_row else 0
        # SUM() is NULL when there are no matching files; report 0 instead.
        stats["processed_rows"] = files_row[1] if files_row and files_row[1] else 0

    return stats
def list_processed_files(db_manager: Optional[DatabaseManager] = None) -> list[dict]:
    """
    Return the processed_files records, most recently processed first.

    Rows are read by column name, so the connection is expected to return
    name-addressable rows.

    Args:
        db_manager: DatabaseManager instance. Uses default if not provided.

    Returns:
        List of dictionaries, one per file, keyed by column name.
    """
    if db_manager is None:
        db_manager = DatabaseManager()

    # Keys of each returned dictionary, in the order they are selected.
    field_names = (
        "file_path", "file_name", "file_hash", "file_size_bytes",
        "row_count", "status", "error_message",
        "first_processed_at", "last_processed_at", "processing_duration_seconds",
    )

    with db_manager.get_connection() as conn:
        records = conn.execute("""
            SELECT file_path, file_name, file_hash, file_size_bytes,
                   row_count, status, error_message,
                   first_processed_at, last_processed_at, processing_duration_seconds
            FROM processed_files
            ORDER BY last_processed_at DESC
        """).fetchall()
        files = [{name: record[name] for name in field_names} for record in records]

    return files
# =============================================================================
# Materialized View Refresh Functions
# =============================================================================
@dataclass
class MVRefreshResult:
    """Outcome of rebuilding the patient treatment summary materialized view.

    ``str()`` of an instance gives a one-line summary of the refresh.
    """

    patients_processed: int
    rows_inserted: int
    refresh_time_seconds: float
    success: bool
    error_message: Optional[str] = None

    def __str__(self) -> str:
        if not self.success:
            return f"MV refresh FAILED: {self.error_message}"
        return (
            f"Refreshed MV: {self.patients_processed:,} patients "
            f"in {self.refresh_time_seconds:.1f}s"
        )
def refresh_patient_treatment_summary(
    db_manager: Optional[DatabaseManager] = None,
    progress_callback: Optional[Callable[[int, int], None]] = None
) -> MVRefreshResult:
    """
    Refresh the mv_patient_treatment_summary materialized view.

    This computes per-patient aggregations from fact_interventions:
    - First/last seen dates and the number of days between them
    - Total cost, average cost per intervention
    - Intervention count, unique drug count
    - Drug sequence (pipe-separated, ordered by each drug's first date)
    - Drug counts, costs, and date ranges (as JSON text)
    - org_name / directory, taken as MIN(...) over the patient's rows

    The MV is fully rebuilt (delete all rows, then re-insert) inside one
    transaction. The existing rows are cleared even when fact_interventions
    is empty, so the MV never keeps summaries for data that no longer exists.

    Errors are not raised: they are logged and reported through the returned
    result (success=False, error_message set).

    Args:
        db_manager: DatabaseManager instance. Uses default if not provided.
        progress_callback: Optional callback(rows_inserted, total_patients),
            invoked once after the rebuild has finished.

    Returns:
        MVRefreshResult with refresh statistics.
    """
    if db_manager is None:
        db_manager = DatabaseManager()

    logger.info("Starting materialized view refresh...")
    start_time = time.time()

    try:
        with db_manager.get_transaction() as conn:
            # Step 1: Get total patient count for progress reporting
            cursor = conn.execute("SELECT COUNT(DISTINCT upid) FROM fact_interventions")
            total_patients = cursor.fetchone()[0]
            logger.info(f"Processing {total_patients:,} unique patients")

            # Step 2: Clear existing MV data. This must happen before the
            # empty-table shortcut below, otherwise stale rows would survive.
            conn.execute("DELETE FROM mv_patient_treatment_summary")
            logger.info("Cleared existing MV data")

            if total_patients == 0:
                logger.warning("No patient data in fact_interventions, MV will be empty")
                return MVRefreshResult(
                    patients_processed=0,
                    rows_inserted=0,
                    refresh_time_seconds=time.time() - start_time,
                    success=True
                )

            # Step 3: Compute aggregations using SQL CTEs
            # This is more efficient than processing row-by-row in Python
            # NOTE(review): drug_sequence relies on GROUP_CONCAT following the
            # ORDER BY of its subquery, and the JSON columns are built by plain
            # string concatenation without escaping drug names — confirm both
            # are acceptable for the SQLite version and data in use.
            refresh_sql = """
                WITH patient_aggs AS (
                    -- Basic aggregations per patient
                    SELECT
                        upid,
                        MIN(org_name) as org_name,
                        MIN(directory) as directory,
                        MIN(intervention_date) as first_seen_date,
                        MAX(intervention_date) as last_seen_date,
                        JULIANDAY(MAX(intervention_date)) - JULIANDAY(MIN(intervention_date)) as days_treated,
                        SUM(price_actual) as total_cost,
                        AVG(price_actual) as avg_cost_per_intervention,
                        COUNT(*) as intervention_count,
                        COUNT(DISTINCT drug_name_std) as unique_drug_count,
                        COUNT(*) as source_row_count
                    FROM fact_interventions
                    GROUP BY upid
                ),
                drug_sequences AS (
                    -- Drug sequence per patient (chronological order, pipe-separated)
                    SELECT
                        upid,
                        GROUP_CONCAT(drug_name_std, '|') as drug_sequence
                    FROM (
                        SELECT DISTINCT
                            upid,
                            drug_name_std,
                            MIN(intervention_date) as first_date
                        FROM fact_interventions
                        GROUP BY upid, drug_name_std
                        ORDER BY upid, first_date
                    )
                    GROUP BY upid
                ),
                drug_counts AS (
                    -- JSON object of drug counts per patient
                    SELECT
                        upid,
                        '{' || GROUP_CONCAT('"' || drug_name_std || '": ' || cnt, ', ') || '}' as drug_counts_json
                    FROM (
                        SELECT
                            upid,
                            drug_name_std,
                            COUNT(*) as cnt
                        FROM fact_interventions
                        GROUP BY upid, drug_name_std
                    )
                    GROUP BY upid
                ),
                drug_costs AS (
                    -- JSON object of drug costs per patient
                    SELECT
                        upid,
                        '{' || GROUP_CONCAT('"' || drug_name_std || '": ' || ROUND(total_cost, 2), ', ') || '}' as drug_costs_json
                    FROM (
                        SELECT
                            upid,
                            drug_name_std,
                            SUM(price_actual) as total_cost
                        FROM fact_interventions
                        GROUP BY upid, drug_name_std
                    )
                    GROUP BY upid
                ),
                drug_dates AS (
                    -- JSON object of drug date ranges per patient
                    SELECT
                        upid,
                        '{' || GROUP_CONCAT('"' || drug_name_std || '": {"first": "' || first_date || '", "last": "' || last_date || '"}', ', ') || '}' as drug_date_ranges_json
                    FROM (
                        SELECT
                            upid,
                            drug_name_std,
                            MIN(intervention_date) as first_date,
                            MAX(intervention_date) as last_date
                        FROM fact_interventions
                        GROUP BY upid, drug_name_std
                    )
                    GROUP BY upid
                )
                INSERT INTO mv_patient_treatment_summary (
                    upid, org_name, directory,
                    first_seen_date, last_seen_date, days_treated,
                    total_cost, avg_cost_per_intervention,
                    intervention_count, unique_drug_count,
                    drug_sequence, drug_counts_json, drug_costs_json, drug_date_ranges_json,
                    source_row_count, computed_at
                )
                SELECT
                    pa.upid,
                    pa.org_name,
                    pa.directory,
                    pa.first_seen_date,
                    pa.last_seen_date,
                    CAST(pa.days_treated AS INTEGER),
                    pa.total_cost,
                    pa.avg_cost_per_intervention,
                    pa.intervention_count,
                    pa.unique_drug_count,
                    ds.drug_sequence,
                    dc.drug_counts_json,
                    dco.drug_costs_json,
                    dd.drug_date_ranges_json,
                    pa.source_row_count,
                    CURRENT_TIMESTAMP
                FROM patient_aggs pa
                LEFT JOIN drug_sequences ds ON pa.upid = ds.upid
                LEFT JOIN drug_counts dc ON pa.upid = dc.upid
                LEFT JOIN drug_costs dco ON pa.upid = dco.upid
                LEFT JOIN drug_dates dd ON pa.upid = dd.upid
            """
            logger.info("Executing MV refresh query...")
            conn.execute(refresh_sql)

            # Get actual rows inserted
            cursor = conn.execute("SELECT COUNT(*) FROM mv_patient_treatment_summary")
            rows_inserted = cursor.fetchone()[0]

            refresh_time = time.time() - start_time
            logger.info(f"MV refresh complete: {rows_inserted:,} rows in {refresh_time:.1f}s")

            # Report progress if callback provided
            if progress_callback:
                progress_callback(rows_inserted, total_patients)

            return MVRefreshResult(
                patients_processed=total_patients,
                rows_inserted=rows_inserted,
                refresh_time_seconds=refresh_time,
                success=True
            )

    except Exception as e:
        refresh_time = time.time() - start_time
        error_msg = str(e)
        logger.error(f"MV refresh failed: {error_msg}")
        return MVRefreshResult(
            patients_processed=0,
            rows_inserted=0,
            refresh_time_seconds=refresh_time,
            success=False,
            error_message=error_msg
        )
def get_patient_summary_stats(db_manager: Optional[DatabaseManager] = None) -> dict:
    """
    Summarise the contents of the patient treatment summary MV.

    Args:
        db_manager: DatabaseManager instance. Uses default if not provided.

    Returns:
        Dictionary with MV statistics. When the MV is empty it contains only
        "total_patients" (0); otherwise it also holds cost, intervention,
        drug and duration aggregates, the overall date range, the latest
        computed_at value, and distinct directory/organization counts.
    """
    if db_manager is None:
        db_manager = DatabaseManager()

    # Result keys for the first six aggregate columns, in SELECT order.
    numeric_keys = (
        "total_cost",
        "avg_cost_per_patient",
        "total_interventions",
        "avg_interventions_per_patient",
        "avg_drugs_per_patient",
        "avg_days_treated",
    )

    with db_manager.get_connection() as conn:
        patient_total = conn.execute(
            "SELECT COUNT(*) FROM mv_patient_treatment_summary"
        ).fetchone()[0]
        stats = {"total_patients": patient_total}
        if patient_total == 0:
            return stats

        # Aggregated statistics
        aggregates = conn.execute("""
            SELECT
                SUM(total_cost) as total_cost_all,
                AVG(total_cost) as avg_cost_per_patient,
                SUM(intervention_count) as total_interventions,
                AVG(intervention_count) as avg_interventions_per_patient,
                AVG(unique_drug_count) as avg_drugs_per_patient,
                AVG(days_treated) as avg_days_treated,
                MIN(first_seen_date) as earliest_date,
                MAX(last_seen_date) as latest_date,
                MAX(computed_at) as last_refresh
            FROM mv_patient_treatment_summary
        """).fetchone()

        # NULL (or zero) aggregates are reported as 0.
        for position, key in enumerate(numeric_keys):
            stats[key] = aggregates[position] or 0
        stats["date_range"] = (aggregates[6], aggregates[7])
        stats["last_refresh"] = aggregates[8]

        # Unique directories in MV
        stats["unique_directories"] = conn.execute(
            "SELECT COUNT(DISTINCT directory) FROM mv_patient_treatment_summary"
        ).fetchone()[0]

        # Unique organizations in MV
        stats["unique_organizations"] = conn.execute(
            "SELECT COUNT(DISTINCT org_name) FROM mv_patient_treatment_summary"
        ).fetchone()[0]

    return stats
def verify_mv_consistency(db_manager: Optional[DatabaseManager] = None) -> tuple[bool, str]:
    """
    Compare the MV's totals against fact_interventions.

    Three figures are compared: the number of patients, the number of
    interventions, and the summed cost (allowing a difference of up to 0.01).

    Args:
        db_manager: DatabaseManager instance. Uses default if not provided.

    Returns:
        Tuple of (is_consistent, message). On mismatch the message lists
        every difference found, separated by "; ".
    """
    if db_manager is None:
        db_manager = DatabaseManager()

    with db_manager.get_connection() as conn:
        # Totals taken directly from the fact table
        fact_row = conn.execute("""
            SELECT
                COUNT(DISTINCT upid) as patients,
                SUM(price_actual) as total_cost,
                COUNT(*) as interventions
            FROM fact_interventions
        """).fetchone()

        # The same totals as stored in the MV
        mv_row = conn.execute("""
            SELECT
                COUNT(*) as patients,
                SUM(total_cost) as total_cost,
                SUM(intervention_count) as interventions
            FROM mv_patient_treatment_summary
        """).fetchone()

    # SUM() is NULL on an empty table; treat every missing figure as 0.
    fact_patients, fact_cost, fact_interventions = (
        fact_row[0] or 0, fact_row[1] or 0, fact_row[2] or 0
    )
    mv_patients, mv_cost, mv_interventions = (
        mv_row[0] or 0, mv_row[1] or 0, mv_row[2] or 0
    )

    problems = []
    if fact_patients != mv_patients:
        problems.append(f"Patient count mismatch: fact={fact_patients:,}, mv={mv_patients:,}")
    if mv_interventions != fact_interventions:
        problems.append(f"Intervention count mismatch: fact={fact_interventions:,}, mv={mv_interventions:,}")

    # Allow small floating point differences in cost
    cost_diff = abs(fact_cost - mv_cost)
    if cost_diff > 0.01:
        problems.append(f"Cost mismatch: fact={fact_cost:,.2f}, mv={mv_cost:,.2f}, diff={cost_diff:.2f}")

    if not problems:
        return True, f"MV consistent: {mv_patients:,} patients, {mv_interventions:,} interventions, £{mv_cost:,.2f} total"
    return False, "; ".join(problems)
File diff suppressed because it is too large Load Diff
+665
View File
@@ -0,0 +1,665 @@
"""
SQLite schema definitions for NHS High-Cost Drug Patient Pathway Analysis Tool.
Contains SQL strings for creating reference tables, fact tables, and indexes.
Schema design supports:
- Reference data from CSV files (drug names, organizations, directories)
- Drug-directory mappings with single-valid-directory flag
- Patient intervention facts with proper indexing
- Cached aggregations for performance
- File tracking for incremental updates
"""
from typing import Optional
import sqlite3
from core.logging_config import get_logger
logger = get_logger(__name__)
# =============================================================================
# Reference Table Schemas
# =============================================================================
# DDL for ref_drug_names: one row per raw drug name with its standardized name,
# plus lookup indexes on both columns. All statements use IF NOT EXISTS.
# NOTE(review): raw_name is declared UNIQUE, which SQLite already backs with an
# automatic index, so idx_ref_drug_names_raw looks redundant — confirm before
# removing it.
REF_DRUG_NAMES_SCHEMA = """
-- Mapping from raw drug names (as they appear in source data) to standardized names
-- Source: data/drugnames.csv
CREATE TABLE IF NOT EXISTS ref_drug_names (
id INTEGER PRIMARY KEY AUTOINCREMENT,
raw_name TEXT NOT NULL UNIQUE,
standard_name TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- Index for fast lookups during data transformation
CREATE INDEX IF NOT EXISTS idx_ref_drug_names_raw ON ref_drug_names(raw_name);
CREATE INDEX IF NOT EXISTS idx_ref_drug_names_standard ON ref_drug_names(standard_name);
"""
# DDL for ref_organizations: one row per organization code with its name.
# NOTE(review): org_code is declared UNIQUE (automatic index in SQLite), so
# idx_ref_organizations_code looks redundant — confirm before removing it.
REF_ORGANIZATIONS_SCHEMA = """
-- NHS organization codes and names
-- Source: data/org_codes.csv
CREATE TABLE IF NOT EXISTS ref_organizations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
org_code TEXT NOT NULL UNIQUE,
org_name TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- Index for fast lookups by organization code
CREATE INDEX IF NOT EXISTS idx_ref_organizations_code ON ref_organizations(org_code);
"""
# DDL for ref_directories: the list of directory (specialty) names.
# NOTE(review): directory_name is declared UNIQUE (automatic index in SQLite),
# so idx_ref_directories_name looks redundant — confirm before removing it.
REF_DIRECTORIES_SCHEMA = """
-- Medical directories/specialties
-- Source: data/directory_list.csv
CREATE TABLE IF NOT EXISTS ref_directories (
id INTEGER PRIMARY KEY AUTOINCREMENT,
directory_name TEXT NOT NULL UNIQUE,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- Index for fast lookups by directory name
CREATE INDEX IF NOT EXISTS idx_ref_directories_name ON ref_directories(directory_name);
"""
# DDL for ref_drug_directory_map: one row per (drug_name, directory_name) pair,
# unique on that pair, with an is_single_valid flag defaulting to 0.
# Indexes cover lookup by drug, by directory, and by the flag.
REF_DRUG_DIRECTORY_MAP_SCHEMA = """
-- Mapping from drug names to valid directories
-- Source: data/drug_directory_list.csv
-- A drug may map to multiple directories (one row per drug-directory pair)
-- The is_single_valid flag indicates drugs with exactly ONE valid directory,
-- which enables automatic directory assignment in department_identification()
CREATE TABLE IF NOT EXISTS ref_drug_directory_map (
id INTEGER PRIMARY KEY AUTOINCREMENT,
drug_name TEXT NOT NULL,
directory_name TEXT NOT NULL,
is_single_valid BOOLEAN NOT NULL DEFAULT 0,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE(drug_name, directory_name)
);
-- Index for looking up directories by drug name (most common access pattern)
CREATE INDEX IF NOT EXISTS idx_ref_drug_directory_map_drug ON ref_drug_directory_map(drug_name);
-- Index for reverse lookup (find drugs by directory)
CREATE INDEX IF NOT EXISTS idx_ref_drug_directory_map_directory ON ref_drug_directory_map(directory_name);
-- Index for quick filtering of single-valid drugs
CREATE INDEX IF NOT EXISTS idx_ref_drug_directory_map_single ON ref_drug_directory_map(is_single_valid);
"""
# DDL for ref_drug_indication_clusters: one row per
# (drug_name, indication, cluster_id) combination, unique on that triple.
# cluster_description and nice_ta_reference are optional (nullable) columns.
# Indexes cover lookup by drug, by cluster and by indication.
REF_DRUG_INDICATION_CLUSTERS_SCHEMA = """
-- Mapping from drugs to SNOMED clusters for indication validation
-- Source: data/drug_indication_clusters.csv
-- Used to validate that patients have appropriate GP diagnoses for their prescribed drugs
-- A drug may map to multiple clusters (one row per drug-indication-cluster combination)
CREATE TABLE IF NOT EXISTS ref_drug_indication_clusters (
id INTEGER PRIMARY KEY AUTOINCREMENT,
drug_name TEXT NOT NULL,
indication TEXT NOT NULL,
cluster_id TEXT NOT NULL,
cluster_description TEXT,
nice_ta_reference TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE(drug_name, indication, cluster_id)
);
-- Index for looking up clusters by drug name (most common access pattern)
CREATE INDEX IF NOT EXISTS idx_ref_drug_indication_clusters_drug ON ref_drug_indication_clusters(drug_name);
-- Index for looking up drugs by cluster (for finding all drugs treating a condition)
CREATE INDEX IF NOT EXISTS idx_ref_drug_indication_clusters_cluster ON ref_drug_indication_clusters(cluster_id);
-- Index for looking up by indication text
CREATE INDEX IF NOT EXISTS idx_ref_drug_indication_clusters_indication ON ref_drug_indication_clusters(indication);
"""
# =============================================================================
# Fact Table Schemas
# =============================================================================
# DDL for fact_interventions: one row per patient intervention event.
# Required (NOT NULL) columns are upid, provider_code, person_key,
# drug_name_std, intervention_date and price_actual (default 0); everything
# else is nullable. source_file records which input file a row came from.
# NOTE(review): idx_fact_interventions_upid and idx_fact_interventions_org are
# each the leading column of a composite index defined below
# (upid_date and composite respectively); SQLite can generally serve
# leftmost-column lookups from the composite index, so the single-column ones
# may be redundant — measure before dropping.
FACT_INTERVENTIONS_SCHEMA = """
-- Patient intervention records (fact table)
-- Source: HCD activity data (CSV/Parquet files or Snowflake)
-- This is the main fact table storing all patient intervention events
CREATE TABLE IF NOT EXISTS fact_interventions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- Patient identification
upid TEXT NOT NULL, -- Unique Patient ID (Provider Code[:3] + PersonKey)
provider_code TEXT NOT NULL, -- Original provider code (3-5 chars)
person_key TEXT NOT NULL, -- Patient key from source system
-- Intervention details
drug_name_raw TEXT, -- Original drug name from source
drug_name_std TEXT NOT NULL, -- Standardized drug name (via ref_drug_names)
intervention_date DATE NOT NULL, -- Date of intervention
price_actual REAL NOT NULL DEFAULT 0, -- Cost of intervention in GBP
-- Organization and directory
org_name TEXT, -- Organization name (cleaned, no commas)
directory TEXT, -- Medical directory/specialty (may be "Undefined")
-- Source tracking
source_file TEXT, -- Original file this record came from
loaded_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-- Additional clinical fields (optional, used in directory fallback logic)
treatment_function_code INTEGER,
additional_detail_1 TEXT,
additional_detail_2 TEXT,
additional_detail_3 TEXT,
additional_detail_4 TEXT,
additional_detail_5 TEXT
);
-- Primary indexes for common filter patterns used in generate_graph()
-- UPID: Used for patient grouping, pathway analysis
CREATE INDEX IF NOT EXISTS idx_fact_interventions_upid ON fact_interventions(upid);
-- Drug name (standardized): Used for drug filtering
CREATE INDEX IF NOT EXISTS idx_fact_interventions_drug ON fact_interventions(drug_name_std);
-- Intervention date: Used for date range filtering (start_date, end_date, last_seen)
CREATE INDEX IF NOT EXISTS idx_fact_interventions_date ON fact_interventions(intervention_date);
-- Directory: Used for directory/specialty filtering
CREATE INDEX IF NOT EXISTS idx_fact_interventions_directory ON fact_interventions(directory);
-- Organization: Used for trust filtering (Provider Code maps to org_name)
CREATE INDEX IF NOT EXISTS idx_fact_interventions_org ON fact_interventions(org_name);
-- Composite index for common filter combination (trust + drug + directory)
CREATE INDEX IF NOT EXISTS idx_fact_interventions_composite
ON fact_interventions(org_name, drug_name_std, directory);
-- Composite index for date-based patient analysis
CREATE INDEX IF NOT EXISTS idx_fact_interventions_upid_date
ON fact_interventions(upid, intervention_date);
"""
# =============================================================================
# Materialized View Schemas (Cached Aggregations)
# =============================================================================
# DDL for mv_patient_treatment_summary: a regular table used as a materialized
# view, holding one row per patient (upid is UNIQUE) with pre-computed
# aggregates and per-drug details stored as JSON text.
# NOTE(review): the in-string comments describe org_name/directory as the
# "first" values seen, but the refresh query in
# refresh_patient_treatment_summary fills them with MIN(org_name) and
# MIN(directory), i.e. the smallest value in sort order — confirm which is
# intended.
# NOTE(review): upid is declared UNIQUE (automatic index in SQLite), so
# idx_mv_patient_summary_upid looks redundant; likewise
# idx_mv_patient_summary_first_seen and idx_mv_patient_summary_org are leading
# columns of the composite indexes below — confirm before removing any.
MV_PATIENT_TREATMENT_SUMMARY_SCHEMA = """
-- Materialized view of patient treatment summaries
-- Pre-computed aggregations per patient for faster pathway analysis
-- Refreshed when fact_interventions data changes
CREATE TABLE IF NOT EXISTS mv_patient_treatment_summary (
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- Patient identification
upid TEXT NOT NULL UNIQUE, -- Unique Patient ID
-- Organization and directory (for filtering)
org_name TEXT, -- Organization name (first org seen)
directory TEXT, -- Primary directory (first directory assigned)
-- Date range
first_seen_date DATE NOT NULL, -- First intervention date
last_seen_date DATE NOT NULL, -- Last intervention date
days_treated INTEGER NOT NULL DEFAULT 0, -- Duration: last_seen - first_seen
-- Cost aggregations
total_cost REAL NOT NULL DEFAULT 0, -- Sum of all intervention costs
avg_cost_per_intervention REAL, -- Average cost per intervention
-- Treatment summary
intervention_count INTEGER NOT NULL DEFAULT 0, -- Total number of interventions
unique_drug_count INTEGER NOT NULL DEFAULT 0, -- Number of distinct drugs
-- Drug sequence (pipe-separated standardized drug names in chronological order)
-- Example: "ADALIMUMAB|ETANERCEPT|INFLIXIMAB"
drug_sequence TEXT,
-- Drug frequency counts (JSON: {"ADALIMUMAB": 5, "ETANERCEPT": 3})
-- Stores count of each drug for this patient
drug_counts_json TEXT,
-- Drug cost totals (JSON: {"ADALIMUMAB": 15000.00, "ETANERCEPT": 8000.00})
-- Stores total cost per drug for this patient
drug_costs_json TEXT,
-- Per-drug date ranges (JSON: {"ADALIMUMAB": {"first": "2023-01-01", "last": "2023-06-15"}, ...})
-- Stores first/last date for each drug
drug_date_ranges_json TEXT,
-- Metadata
computed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
source_row_count INTEGER -- Number of fact_interventions rows used
);
-- Index for fast patient lookup
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_upid ON mv_patient_treatment_summary(upid);
-- Indexes for common filter patterns
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_org ON mv_patient_treatment_summary(org_name);
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_directory ON mv_patient_treatment_summary(directory);
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_first_seen ON mv_patient_treatment_summary(first_seen_date);
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_last_seen ON mv_patient_treatment_summary(last_seen_date);
-- Composite index for date range filtering (common in generate_graph)
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_date_range
ON mv_patient_treatment_summary(first_seen_date, last_seen_date);
-- Composite index for org + directory + dates (full filter pattern)
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_filter_composite
ON mv_patient_treatment_summary(org_name, directory, first_seen_date, last_seen_date);
-- Index for drug sequence pattern matching
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_drug_seq ON mv_patient_treatment_summary(drug_sequence);
"""
# Script that creates every materialized-view table; at present it only
# embeds MV_PATIENT_TREATMENT_SUMMARY_SCHEMA.
MATERIALIZED_VIEWS_SCHEMA = f"""
-- Materialized Views Schema
-- Pre-computed aggregations for performance
{MV_PATIENT_TREATMENT_SUMMARY_SCHEMA}
"""
# =============================================================================
# File Tracking Schemas (Incremental Updates)
# =============================================================================
# DDL script for the processed_files table (one row per file_path, enforced by
# UNIQUE(file_path)) and its lookup indexes. Every statement uses
# IF NOT EXISTS, so the script can be re-run safely.
# NOTE(review): SQLite backs UNIQUE(file_path) with an automatic index, so
# idx_processed_files_path looks redundant — confirm before removing.
PROCESSED_FILES_SCHEMA = """
-- Tracks processed data files for incremental updates
-- Enables detecting changed files by comparing hashes
-- Stores processing status and statistics
CREATE TABLE IF NOT EXISTS processed_files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- File identification
file_path TEXT NOT NULL, -- Full path to the file
file_name TEXT NOT NULL, -- Just the filename (for display)
file_hash TEXT NOT NULL, -- SHA256 hash of file contents
-- File metadata
file_size_bytes INTEGER, -- Size of file in bytes
file_modified_at TIMESTAMP, -- File's last modification timestamp
-- Processing results
row_count INTEGER DEFAULT 0, -- Number of rows processed from this file
status TEXT NOT NULL DEFAULT 'pending', -- pending, processing, success, error
error_message TEXT, -- Error details if status='error'
-- Timestamps
first_processed_at TIMESTAMP, -- When first processed
last_processed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
processing_duration_seconds REAL, -- How long processing took
-- Uniqueness: only one record per file path
-- Hash changes indicate file content changed (needs reprocessing)
UNIQUE(file_path)
);
-- Index for fast lookup by file path
CREATE INDEX IF NOT EXISTS idx_processed_files_path ON processed_files(file_path);
-- Index for finding files by status (e.g., find all pending or errored files)
CREATE INDEX IF NOT EXISTS idx_processed_files_status ON processed_files(status);
-- Index for finding files by hash (detect if same file appears at different paths)
CREATE INDEX IF NOT EXISTS idx_processed_files_hash ON processed_files(file_hash);
-- Index for finding recently processed files
CREATE INDEX IF NOT EXISTS idx_processed_files_last_processed ON processed_files(last_processed_at);
"""
# Script that creates every file-tracking table; at present it only embeds
# PROCESSED_FILES_SCHEMA.
FILE_TRACKING_SCHEMA = f"""
-- File Tracking Schema
-- Supports incremental data loading
{PROCESSED_FILES_SCHEMA}
"""
# =============================================================================
# Combined Schemas
# =============================================================================
# Script that creates all five reference tables by concatenating their
# individual schema scripts in the order listed.
REFERENCE_TABLES_SCHEMA = f"""
-- Reference Tables Schema
-- Contains lookup data migrated from CSV files
{REF_DRUG_NAMES_SCHEMA}
{REF_ORGANIZATIONS_SCHEMA}
{REF_DIRECTORIES_SCHEMA}
{REF_DRUG_DIRECTORY_MAP_SCHEMA}
{REF_DRUG_INDICATION_CLUSTERS_SCHEMA}
"""
# Script that creates the fact tables; at present it only embeds
# FACT_INTERVENTIONS_SCHEMA (the materialized views have their own script).
FACT_TABLES_SCHEMA = f"""
-- Fact Tables Schema
-- Contains patient intervention data
{FACT_INTERVENTIONS_SCHEMA}
"""
# Full database script: reference tables, then fact tables, then materialized
# views, then file tracking, concatenated in that order.
ALL_TABLES_SCHEMA = f"""
-- Complete Database Schema
-- Reference tables + Fact tables + Materialized views + File tracking
{REFERENCE_TABLES_SCHEMA}
{FACT_TABLES_SCHEMA}
{MATERIALIZED_VIEWS_SCHEMA}
{FILE_TRACKING_SCHEMA}
"""
# =============================================================================
# Schema Helper Functions
# =============================================================================
def create_reference_tables(conn: sqlite3.Connection) -> None:
    """
    Run REFERENCE_TABLES_SCHEMA to create the reference (lookup) tables.

    An info message is logged before and after the script runs.

    Args:
        conn: SQLite database connection the schema script is executed on.
    """
    logger.info("Creating reference tables...")
    schema_script = REFERENCE_TABLES_SCHEMA
    conn.executescript(schema_script)
    logger.info("Reference tables created successfully")
def drop_reference_tables(conn: sqlite3.Connection) -> None:
    """
    Remove the five reference tables if they are present.

    Args:
        conn: SQLite database connection the DROP statements are executed on.

    Warning:
        All reference data is deleted. Use with caution.
    """
    logger.warning("Dropping reference tables...")
    # SQL text is kept flush-left so the executed script is unchanged.
    drop_script = """
DROP TABLE IF EXISTS ref_drug_names;
DROP TABLE IF EXISTS ref_organizations;
DROP TABLE IF EXISTS ref_directories;
DROP TABLE IF EXISTS ref_drug_directory_map;
DROP TABLE IF EXISTS ref_drug_indication_clusters;
"""
    conn.executescript(drop_script)
    logger.info("Reference tables dropped")
def get_reference_table_counts(conn: sqlite3.Connection) -> dict[str, int]:
    """
    Count the rows in each reference table.

    Args:
        conn: SQLite database connection to query.

    Returns:
        Dictionary mapping each reference table name to its row count,
        in the fixed order the tables are listed below.
    """
    table_names = (
        "ref_drug_names",
        "ref_organizations",
        "ref_directories",
        "ref_drug_directory_map",
        "ref_drug_indication_clusters",
    )
    row_counts = {}
    for name in table_names:
        # Table names come from the fixed tuple above, never from user input.
        row = conn.execute(f"SELECT COUNT(*) FROM {name}").fetchone()
        row_counts[name] = row[0] if row else 0
    return row_counts
def verify_reference_tables_exist(conn: sqlite3.Connection) -> list[str]:
    """
    Report which reference tables are absent from the database.

    Each expected table is looked up by name in sqlite_master.

    Args:
        conn: SQLite database connection to inspect.

    Returns:
        Names of the missing tables, in declaration order. An empty list
        means every reference table exists.
    """
    expected = (
        "ref_drug_names",
        "ref_organizations",
        "ref_directories",
        "ref_drug_directory_map",
        "ref_drug_indication_clusters",
    )
    lookup_sql = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
    return [
        name
        for name in expected
        if conn.execute(lookup_sql, (name,)).fetchone() is None
    ]
# =============================================================================
# Fact Table Helper Functions
# =============================================================================
def create_fact_tables(conn: sqlite3.Connection) -> None:
    """
    Create the fact tables and then the materialized-view tables.

    Runs FACT_TABLES_SCHEMA followed by MATERIALIZED_VIEWS_SCHEMA.

    Args:
        conn: SQLite database connection the schema scripts are executed on.
    """
    logger.info("Creating fact tables...")
    for schema_script in (FACT_TABLES_SCHEMA, MATERIALIZED_VIEWS_SCHEMA):
        conn.executescript(schema_script)
    logger.info("Fact tables created successfully")
def drop_fact_tables(conn: sqlite3.Connection) -> None:
    """
    Remove fact_interventions and mv_patient_treatment_summary if present.

    Args:
        conn: SQLite database connection the DROP statements are executed on.

    Warning:
        All patient intervention data and its cached summary are deleted.
        Use with caution.
    """
    logger.warning("Dropping fact tables...")
    # SQL text is kept flush-left so the executed script is unchanged.
    drop_script = """
DROP TABLE IF EXISTS fact_interventions;
DROP TABLE IF EXISTS mv_patient_treatment_summary;
"""
    conn.executescript(drop_script)
    logger.info("Fact tables dropped")
def get_fact_table_counts(conn: sqlite3.Connection) -> dict[str, int]:
    """
    Count the rows in the fact table and the materialized summary table.

    Args:
        conn: SQLite database connection to query.

    Returns:
        Dictionary mapping table name to row count.
    """
    row_counts = {}
    for name in ("fact_interventions", "mv_patient_treatment_summary"):
        # Table names come from the fixed tuple above, never from user input.
        row = conn.execute(f"SELECT COUNT(*) FROM {name}").fetchone()
        row_counts[name] = row[0] if row else 0
    return row_counts
def verify_fact_tables_exist(conn: sqlite3.Connection) -> list[str]:
    """
    Report which fact / materialized-view tables are absent from the database.

    Each expected table is looked up by name in sqlite_master.

    Args:
        conn: SQLite database connection to inspect.

    Returns:
        Names of the missing tables. An empty list means all of them exist.
    """
    expected = ("fact_interventions", "mv_patient_treatment_summary")
    lookup_sql = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
    return [
        name
        for name in expected
        if conn.execute(lookup_sql, (name,)).fetchone() is None
    ]
# =============================================================================
# File Tracking Helper Functions
# =============================================================================
def create_file_tracking_tables(conn: sqlite3.Connection) -> None:
    """
    Run FILE_TRACKING_SCHEMA to create the file tracking tables.

    An info message is logged before and after the script runs.

    Args:
        conn: SQLite database connection the schema script is executed on.
    """
    logger.info("Creating file tracking tables...")
    schema_script = FILE_TRACKING_SCHEMA
    conn.executescript(schema_script)
    logger.info("File tracking tables created successfully")
def drop_file_tracking_tables(conn: sqlite3.Connection) -> None:
    """
    Remove the processed_files table if it is present.

    Args:
        conn: SQLite database connection the DROP statement is executed on.

    Warning:
        All file tracking history is deleted.
    """
    logger.warning("Dropping file tracking tables...")
    # SQL text is kept flush-left so the executed script is unchanged.
    drop_script = """
DROP TABLE IF EXISTS processed_files;
"""
    conn.executescript(drop_script)
    logger.info("File tracking tables dropped")
def get_file_tracking_counts(conn: sqlite3.Connection) -> dict[str, int]:
    """
    Count the rows in each file tracking table.

    Args:
        conn: SQLite database connection to query.

    Returns:
        Dictionary mapping table name to row count.
    """
    row_counts = {}
    for name in ("processed_files",):
        # Table names come from the fixed tuple above, never from user input.
        row = conn.execute(f"SELECT COUNT(*) FROM {name}").fetchone()
        row_counts[name] = row[0] if row else 0
    return row_counts
def verify_file_tracking_tables_exist(conn: sqlite3.Connection) -> list[str]:
    """
    Report which file tracking tables are absent from the database.

    Each expected table is looked up by name in sqlite_master.

    Args:
        conn: SQLite database connection to inspect.

    Returns:
        Names of the missing tables. An empty list means all of them exist.
    """
    expected = ("processed_files",)
    lookup_sql = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
    return [
        name
        for name in expected
        if conn.execute(lookup_sql, (name,)).fetchone() is None
    ]
# =============================================================================
# Combined Helper Functions
# =============================================================================
def create_all_tables(conn: sqlite3.Connection) -> None:
    """
    Run ALL_TABLES_SCHEMA to create every table in the database.

    The combined script covers the reference tables, fact tables,
    materialized-view tables and file tracking tables.

    Args:
        conn: SQLite database connection the schema script is executed on.
    """
    logger.info("Creating all database tables...")
    schema_script = ALL_TABLES_SCHEMA
    conn.executescript(schema_script)
    logger.info("All tables created successfully")
def drop_all_tables(conn: sqlite3.Connection) -> None:
    """
    Drop every table group: file tracking, then fact, then reference.

    Args:
        conn: SQLite database connection the DROP statements are executed on.

    Warning:
        All data is deleted. Use with extreme caution.
    """
    logger.warning("Dropping all tables...")
    drop_steps = (
        drop_file_tracking_tables,
        drop_fact_tables,
        drop_reference_tables,
    )
    for drop_step in drop_steps:
        drop_step(conn)
    logger.info("All tables dropped")
def get_all_table_counts(conn: sqlite3.Connection) -> dict[str, int]:
    """
    Collect row counts for the reference, fact and file tracking tables.

    Args:
        conn: SQLite database connection to query.

    Returns:
        Dictionary mapping table name to row count, merged from the three
        per-group count helpers in that order.
    """
    return {
        **get_reference_table_counts(conn),
        **get_fact_table_counts(conn),
        **get_file_tracking_counts(conn),
    }
def verify_all_tables_exist(conn: sqlite3.Connection) -> list[str]:
    """
    Report every expected table that is absent from the database.

    Args:
        conn: SQLite database connection to inspect.

    Returns:
        Missing table names, listed as reference tables first, then fact
        tables, then file tracking tables. An empty list means all exist.
    """
    return [
        *verify_reference_tables_exist(conn),
        *verify_fact_tables_exist(conn),
        *verify_file_tracking_tables_exist(conn),
    ]
+797
View File
@@ -0,0 +1,797 @@
"""
Snowflake connector module for NHS Patient Pathway Analysis.
Provides connection handling with SSO browser authentication for NHS environments.
Uses the externalbrowser authenticator which opens a browser window for NHS identity
management authentication.
Usage:
from data_processing.snowflake_connector import SnowflakeConnector, get_connector
# Using context manager (recommended); the connector itself is yielded and
# its connection is closed on exit
with get_connector() as connector:
    with connector.get_cursor() as cursor:
        cursor.execute("SELECT * FROM table LIMIT 10")
        results = cursor.fetchall()
# Manual connection management
connector = SnowflakeConnector()
try:
conn = connector.connect()
cursor = conn.cursor()
# ... use cursor ...
finally:
connector.close()
"""
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import date, datetime
from pathlib import Path
from typing import Any, Generator, Optional, TYPE_CHECKING
import time
# Snowflake connector is an optional dependency
SNOWFLAKE_AVAILABLE = False
try:
import snowflake.connector
from snowflake.connector import SnowflakeConnection
from snowflake.connector.cursor import SnowflakeCursor
SNOWFLAKE_AVAILABLE = True
except ImportError:
snowflake = None # type: ignore[assignment]
# Type hints for when snowflake is not available
if TYPE_CHECKING:
from snowflake.connector import SnowflakeConnection
from snowflake.connector.cursor import SnowflakeCursor
from config import get_snowflake_config, SnowflakeConfig
from core.logging_config import get_logger
logger = get_logger(__name__)
class SnowflakeConnectionError(Exception):
    """Signals that a connection to Snowflake could not be established."""
class SnowflakeNotConfiguredError(Exception):
    """Signals that no Snowflake account has been configured."""
class SnowflakeNotAvailableError(Exception):
    """Signals that the snowflake-connector-python package is not installed."""
@dataclass
class ConnectionInfo:
    """
    Snapshot of the connector's current connection state.

    All fields default to an "unconnected" value, so ConnectionInfo() on its
    own describes a connector that has not connected yet.
    """
    # Whether a connection is currently established
    connected: bool = False
    # Connection settings the session was opened with
    account: str = ""
    warehouse: str = ""
    database: str = ""
    schema: str = ""
    # Identity of the session
    user: str = ""
    role: str = ""
    # Usage bookkeeping
    connected_at: Optional[datetime] = None
    last_query_at: Optional[datetime] = None
    query_count: int = 0
class SnowflakeConnector:
"""
Manages Snowflake connections with SSO browser authentication.
This class provides connection management for NHS Snowflake access using
the externalbrowser authenticator which triggers NHS SSO login via browser.
Attributes:
config: SnowflakeConfig with connection settings
connection_info: ConnectionInfo tracking current state
Example:
connector = SnowflakeConnector()
with connector.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT CURRENT_USER()")
print(cursor.fetchone()[0])
"""
def __init__(self, config: Optional[SnowflakeConfig] = None):
"""
Initialize the connector with configuration.
Args:
config: Optional SnowflakeConfig. If not provided, loads from
config/snowflake.toml using get_snowflake_config().
"""
self._config = config or get_snowflake_config()
self._connection: Optional[SnowflakeConnection] = None
self._connection_info = ConnectionInfo()
@property
def config(self) -> SnowflakeConfig:
"""Return the Snowflake configuration."""
return self._config
@property
def connection_info(self) -> ConnectionInfo:
"""Return information about the current connection state."""
return self._connection_info
@property
def is_connected(self) -> bool:
"""Return True if currently connected to Snowflake."""
return self._connection is not None and not self._connection.is_closed()
def _check_availability(self) -> None:
"""Check that snowflake-connector-python is installed."""
if not SNOWFLAKE_AVAILABLE:
raise SnowflakeNotAvailableError(
"snowflake-connector-python is not installed. "
"Install it with: pip install snowflake-connector-python"
)
def _check_configured(self) -> None:
"""Check that Snowflake is configured."""
if not self._config.is_configured:
raise SnowflakeNotConfiguredError(
"Snowflake account is not configured. "
"Edit config/snowflake.toml and set connection.account"
)
def connect(self) -> SnowflakeConnection:
"""
Establish a connection to Snowflake.
Uses the externalbrowser authenticator which opens a browser window
for NHS SSO authentication. The browser popup is expected and normal.
Returns:
Active SnowflakeConnection
Raises:
SnowflakeNotAvailableError: If snowflake-connector-python not installed
SnowflakeNotConfiguredError: If account is not configured
SnowflakeConnectionError: If connection fails
"""
self._check_availability()
self._check_configured()
# Close existing connection if any
if self._connection is not None:
self.close()
conn_cfg = self._config.connection
timeout_cfg = self._config.timeouts
logger.info(f"Connecting to Snowflake account: {conn_cfg.account}")
logger.info(f"Using warehouse: {conn_cfg.warehouse}, database: {conn_cfg.database}")
logger.info(f"Authenticator: {conn_cfg.authenticator}")
if conn_cfg.authenticator == "externalbrowser":
logger.info("Browser window will open for NHS SSO authentication")
start_time = time.time()
try:
# Build connection parameters
connect_params = {
"account": conn_cfg.account,
"warehouse": conn_cfg.warehouse,
"database": conn_cfg.database,
"schema": conn_cfg.schema,
"authenticator": conn_cfg.authenticator,
"login_timeout": timeout_cfg.login_timeout,
"network_timeout": timeout_cfg.connection_timeout,
}
# Optional parameters (only add if set)
if conn_cfg.user:
connect_params["user"] = conn_cfg.user
if conn_cfg.role:
connect_params["role"] = conn_cfg.role
self._connection = snowflake.connector.connect(**connect_params)
elapsed = time.time() - start_time
logger.info(f"Connected to Snowflake successfully in {elapsed:.1f}s")
# Update connection info
self._connection_info = ConnectionInfo(
connected=True,
account=conn_cfg.account,
warehouse=conn_cfg.warehouse,
database=conn_cfg.database,
schema=conn_cfg.schema,
user=self._get_current_user(),
role=self._get_current_role(),
connected_at=datetime.now(),
query_count=0,
)
return self._connection
except Exception as e:
elapsed = time.time() - start_time
logger.error(f"Failed to connect to Snowflake after {elapsed:.1f}s: {e}")
self._connection_info = ConnectionInfo(connected=False)
raise SnowflakeConnectionError(f"Failed to connect to Snowflake: {e}") from e
def close(self) -> None:
"""Close the Snowflake connection if open."""
if self._connection is not None:
try:
self._connection.close()
logger.info("Snowflake connection closed")
except Exception as e:
logger.warning(f"Error closing Snowflake connection: {e}")
finally:
self._connection = None
self._connection_info = ConnectionInfo(connected=False)
def _get_current_user(self) -> str:
"""Get the current authenticated user."""
if self._connection is None:
return ""
try:
cursor = self._connection.cursor()
cursor.execute("SELECT CURRENT_USER()")
result = cursor.fetchone()
return result[0] if result else ""
except Exception:
return ""
def _get_current_role(self) -> str:
"""Get the current active role."""
if self._connection is None:
return ""
try:
cursor = self._connection.cursor()
cursor.execute("SELECT CURRENT_ROLE()")
result = cursor.fetchone()
return result[0] if result else ""
except Exception:
return ""
@contextmanager
def get_connection(self) -> Generator[SnowflakeConnection, None, None]:
"""
Context manager for connection handling.
Creates a new connection if not already connected, yields the connection,
and ensures proper cleanup on exit.
Yields:
Active SnowflakeConnection
Example:
connector = SnowflakeConnector()
with connector.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT 1")
"""
if not self.is_connected:
self.connect()
assert self._connection is not None, "Connection should be established"
try:
yield self._connection
finally:
# Keep connection open for reuse
pass
@contextmanager
def get_cursor(
self,
dict_cursor: bool = False
) -> Generator[SnowflakeCursor, None, None]:
"""
Context manager that provides a cursor.
Args:
dict_cursor: If True, returns cursor that yields dict-like rows
Yields:
SnowflakeCursor for executing queries
Example:
connector = SnowflakeConnector()
with connector.get_cursor() as cursor:
cursor.execute("SELECT * FROM table LIMIT 10")
for row in cursor:
print(row)
"""
if not self.is_connected:
self.connect()
assert self._connection is not None, "Connection should be established"
cursor: Any = None
try:
if dict_cursor:
cursor = self._connection.cursor(snowflake.connector.DictCursor) # type: ignore[union-attr]
else:
cursor = self._connection.cursor()
yield cursor # type: ignore[misc]
self._connection_info.last_query_at = datetime.now()
self._connection_info.query_count += 1
finally:
if cursor is not None:
cursor.close()
def execute(
self,
query: str,
params: Optional[tuple] = None,
timeout: Optional[int] = None
) -> list[tuple]:
"""
Execute a query and return all results.
Args:
query: SQL query to execute
params: Optional query parameters for parameterized queries
timeout: Optional query timeout in seconds (overrides config)
Returns:
List of result rows as tuples
Raises:
SnowflakeConnectionError: If not connected
Various snowflake errors for query issues
"""
if not self.is_connected:
self.connect()
effective_timeout = timeout or self._config.timeouts.query_timeout
with self.get_cursor() as cursor:
logger.info(f"Executing query (timeout={effective_timeout}s)")
logger.debug(f"Query: {query[:200]}...")
if effective_timeout > 0:
cursor.execute(f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {effective_timeout}")
start_time = time.time()
cursor.execute(query, params)
results = cursor.fetchall()
elapsed = time.time() - start_time
logger.info(f"Query returned {len(results)} rows in {elapsed:.2f}s")
return results
def execute_dict(
self,
query: str,
params: Optional[tuple] = None,
timeout: Optional[int] = None
) -> list[dict]:
"""
Execute a query and return results as list of dictionaries.
Args:
query: SQL query to execute
params: Optional query parameters
timeout: Optional query timeout in seconds
Returns:
List of result rows as dictionaries
"""
if not self.is_connected:
self.connect()
effective_timeout = timeout or self._config.timeouts.query_timeout
with self.get_cursor(dict_cursor=True) as cursor:
logger.info(f"Executing query (timeout={effective_timeout}s)")
logger.debug(f"Query: {query[:200]}...")
if effective_timeout > 0:
cursor.execute(f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {effective_timeout}")
start_time = time.time()
cursor.execute(query, params)
results = cursor.fetchall()
elapsed = time.time() - start_time
logger.info(f"Query returned {len(results)} rows in {elapsed:.2f}s")
return results # type: ignore[return-value]
def execute_chunked(
self,
query: str,
params: Optional[tuple] = None,
chunk_size: Optional[int] = None,
timeout: Optional[int] = None,
max_rows: Optional[int] = None,
) -> Generator[list[tuple], None, None]:
"""
Execute a query and yield results in chunks for memory efficiency.
This method is useful for large result sets that would exceed memory
if loaded all at once. Results are yielded as chunks of rows.
Args:
query: SQL query to execute
params: Optional query parameters for parameterized queries
chunk_size: Number of rows per chunk (default from config)
timeout: Optional query timeout in seconds (overrides config)
max_rows: Maximum total rows to return (default from config, 0 for no limit)
Yields:
List of result rows as tuples for each chunk
Example:
for chunk in connector.execute_chunked("SELECT * FROM large_table"):
process_chunk(chunk)
"""
if not self.is_connected:
self.connect()
effective_timeout = timeout or self._config.timeouts.query_timeout
effective_chunk_size = chunk_size or self._config.query.chunk_size
effective_max_rows = max_rows if max_rows is not None else self._config.query.max_rows
with self.get_cursor() as cursor:
logger.info(f"Executing chunked query (chunk_size={effective_chunk_size}, timeout={effective_timeout}s)")
logger.debug(f"Query: {query[:200]}...")
if effective_timeout > 0:
cursor.execute(f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {effective_timeout}")
start_time = time.time()
cursor.execute(query, params)
total_rows = 0
chunk_num = 0
while True:
# Determine how many rows to fetch this chunk
if effective_max_rows > 0:
remaining = effective_max_rows - total_rows
if remaining <= 0:
break
fetch_size = min(effective_chunk_size, remaining)
else:
fetch_size = effective_chunk_size
chunk = cursor.fetchmany(fetch_size)
if not chunk:
break
chunk_num += 1
total_rows += len(chunk)
logger.debug(f"Chunk {chunk_num}: {len(chunk)} rows (total: {total_rows})")
yield chunk
elapsed = time.time() - start_time
logger.info(f"Chunked query returned {total_rows} rows in {chunk_num} chunks ({elapsed:.2f}s)")
def execute_chunked_dict(
self,
query: str,
params: Optional[tuple] = None,
chunk_size: Optional[int] = None,
timeout: Optional[int] = None,
max_rows: Optional[int] = None,
) -> Generator[list[dict], None, None]:
"""
Execute a query and yield dict results in chunks for memory efficiency.
Same as execute_chunked but returns rows as dictionaries.
Args:
query: SQL query to execute
params: Optional query parameters
chunk_size: Number of rows per chunk (default from config)
timeout: Optional query timeout in seconds
max_rows: Maximum total rows to return (default from config, 0 for no limit)
Yields:
List of result rows as dictionaries for each chunk
"""
if not self.is_connected:
self.connect()
effective_timeout = timeout or self._config.timeouts.query_timeout
effective_chunk_size = chunk_size or self._config.query.chunk_size
effective_max_rows = max_rows if max_rows is not None else self._config.query.max_rows
with self.get_cursor(dict_cursor=True) as cursor:
logger.info(f"Executing chunked dict query (chunk_size={effective_chunk_size}, timeout={effective_timeout}s)")
logger.debug(f"Query: {query[:200]}...")
if effective_timeout > 0:
cursor.execute(f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {effective_timeout}")
start_time = time.time()
cursor.execute(query, params)
total_rows = 0
chunk_num = 0
while True:
# Determine how many rows to fetch this chunk
if effective_max_rows > 0:
remaining = effective_max_rows - total_rows
if remaining <= 0:
break
fetch_size = min(effective_chunk_size, remaining)
else:
fetch_size = effective_chunk_size
chunk = cursor.fetchmany(fetch_size)
if not chunk:
break
chunk_num += 1
total_rows += len(chunk)
logger.debug(f"Chunk {chunk_num}: {len(chunk)} rows (total: {total_rows})")
yield chunk # type: ignore[misc]
elapsed = time.time() - start_time
logger.info(f"Chunked dict query returned {total_rows} rows in {chunk_num} chunks ({elapsed:.2f}s)")
def execute_with_row_limit(
self,
query: str,
params: Optional[tuple] = None,
max_rows: Optional[int] = None,
timeout: Optional[int] = None
) -> tuple[list[dict], bool]:
"""
Execute a query with a row limit and indicate if more rows were available.
This is useful for pagination or previewing large result sets.
Args:
query: SQL query to execute
params: Optional query parameters
max_rows: Maximum rows to return (default from config)
timeout: Optional query timeout in seconds
Returns:
Tuple of (results list, has_more bool)
- results: List of result rows as dictionaries (up to max_rows)
- has_more: True if there were more rows than max_rows
"""
if not self.is_connected:
self.connect()
effective_timeout = timeout or self._config.timeouts.query_timeout
effective_max_rows = max_rows if max_rows is not None else self._config.query.max_rows
with self.get_cursor(dict_cursor=True) as cursor:
logger.info(f"Executing query with limit (max_rows={effective_max_rows}, timeout={effective_timeout}s)")
logger.debug(f"Query: {query[:200]}...")
if effective_timeout > 0:
cursor.execute(f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {effective_timeout}")
start_time = time.time()
cursor.execute(query, params)
# Fetch one more than max to detect if there are more rows
results = cursor.fetchmany(effective_max_rows + 1)
elapsed = time.time() - start_time
has_more = len(results) > effective_max_rows
if has_more:
results = results[:effective_max_rows]
logger.info(f"Query returned {len(results)} rows (has_more={has_more}) in {elapsed:.2f}s")
return results, has_more # type: ignore[return-value]
def fetch_activity_data(
self,
start_date: Optional[date] = None,
end_date: Optional[date] = None,
provider_codes: Optional[list[str]] = None,
max_rows: Optional[int] = None,
timeout: Optional[int] = None,
) -> list[dict]:
"""
Fetch high-cost drug activity data from Snowflake.
Queries the CDM.Acute__Conmon__PatientLevelDrugs table and returns
data in a format compatible with the existing analysis pipeline.
Args:
start_date: Optional start date for filtering (inclusive)
end_date: Optional end date for filtering (inclusive)
provider_codes: Optional list of provider codes to filter by
max_rows: Maximum rows to return (default from config)
timeout: Query timeout in seconds (default from config)
Returns:
List of dictionaries with keys matching expected DataFrame columns:
- PseudoNHSNoLinked: Pseudonymised NHS number (for UPID creation)
- Provider Code: NHS provider code
- PersonKey: Local patient identifier
- Drug Name: Raw drug name
- Intervention Date: Date of intervention
- Price Actual: Cost of intervention
- OrganisationName: Provider organisation name
- Treatment Function Code: NHS treatment function code
- Additional Detail 1-5: Additional details for directory identification
Raises:
SnowflakeConnectionError: If not connected or query fails
"""
if not self.is_connected:
self.connect()
# Build the query
table_name = 'DATA_HUB.CDM."Acute__Conmon__PatientLevelDrugs"'
query = f'''
SELECT
"PseudoNHSNoLinked",
"ProviderCode" AS "Provider Code",
"LocalPatientID" AS "PersonKey",
"DrugName" AS "Drug Name",
"InterventionDate" AS "Intervention Date",
"PriceActual" AS "Price Actual",
"ProviderName" AS "OrganisationName",
"TreatmentFunctionCode" AS "Treatment Function Code",
"TreatmentFunctionDesc" AS "Treatment Function Desc",
"AdditionalDetail1" AS "Additional Detail 1",
"AdditionalDescription1" AS "Additional Description 1",
"AdditionalDetail2" AS "Additional Detail 2",
"AdditionalDescription2" AS "Additional Description 2",
"AdditionalDetail3" AS "Additional Detail 3",
"AdditionalDescription3" AS "Additional Description 3",
"AdditionalDetail4" AS "Additional Detail 4",
"AdditionalDescription4" AS "Additional Description 4",
"AdditionalDetail5" AS "Additional Detail 5",
"AdditionalDescription5" AS "Additional Description 5"
FROM {table_name}
WHERE 1=1
'''
params = []
# Add date filters
if start_date:
query += ' AND "InterventionDate" >= %s'
params.append(start_date.isoformat())
if end_date:
query += ' AND "InterventionDate" <= %s'
params.append(end_date.isoformat())
# Add provider filter
if provider_codes:
placeholders = ", ".join(["%s"] * len(provider_codes))
query += f' AND "ProviderCode" IN ({placeholders})'
params.extend(provider_codes)
# Add ordering for consistent results
query += ' ORDER BY "InterventionDate", "ProviderCode", "PseudoNHSNoLinked"'
logger.info(f"Fetching activity data from Snowflake")
if start_date:
logger.info(f" Date range: {start_date} to {end_date or 'now'}")
if provider_codes:
logger.info(f" Providers: {provider_codes}")
effective_max_rows = max_rows if max_rows is not None else self._config.query.max_rows
effective_timeout = timeout or self._config.timeouts.query_timeout
# Execute with chunked results for large datasets
all_results = []
total_rows = 0
for chunk in self.execute_chunked_dict(
query,
params=tuple(params) if params else None,
timeout=effective_timeout,
max_rows=effective_max_rows,
):
all_results.extend(chunk)
total_rows += len(chunk)
logger.debug(f"Fetched {total_rows} rows so far...")
logger.info(f"Fetched {len(all_results)} activity records from Snowflake")
return all_results
def test_connection(self) -> tuple[bool, str]:
    """
    Probe whether a working Snowflake connection can be established.

    Three stages run in order and the first failing stage ends the probe:
    the availability check, the configuration check, and finally a real
    connect followed by a lookup of the current user and role. The
    connection opened by the last stage is not closed here.

    Returns:
        Tuple of (success: bool, message: str). On success the message
        names the connected user and role. A SnowflakeNotAvailableError
        or SnowflakeNotConfiguredError from the checks is reported as
        its own text; any exception from the connect stage is reported
        as "Connection failed: ...".
    """
    # Stage 1: availability check.
    try:
        self._check_availability()
    except SnowflakeNotAvailableError as unavailable:
        return False, str(unavailable)

    # Stage 2: configuration check.
    try:
        self._check_configured()
    except SnowflakeNotConfiguredError as unconfigured:
        return False, str(unconfigured)

    # Stage 3: connect and identify the session. The message is built
    # inside the try so any error while gathering it is reported too.
    try:
        self.connect()
        outcome = (
            True,
            f"Connected as {self._get_current_user()} with role {self._get_current_role()}",
        )
    except Exception as failure:
        outcome = (False, f"Connection failed: {failure}")
    return outcome
def __enter__(self) -> "SnowflakeConnector":
    """
    Enter a ``with`` block: connect, then hand back this connector.

    Returning ``self`` lets ``with connector as c:`` bind the same
    object that was just connected.
    """
    self.connect()
    return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
    """
    Leave a ``with`` block: close the connection.

    The exception details are ignored and nothing is returned, so an
    exception raised inside the ``with`` body is not suppressed.
    """
    self.close()
# Module-level singleton for convenience.
# Starts as None. get_connector() fills it in lazily the first time it is
# called without a config, and reset_connector() closes it and sets it
# back to None.
_default_connector: Optional[SnowflakeConnector] = None
def get_connector(config: Optional[SnowflakeConfig] = None) -> SnowflakeConnector:
    """
    Return a Snowflake connector, sharing one default instance.

    Args:
        config: Optional configuration. When given, a brand-new connector
            is built from it on every call and is NOT stored as the
            default. When None, the shared default connector is returned,
            being created on the first such call.

    Returns:
        SnowflakeConnector instance
    """
    global _default_connector

    if config is None:
        # Shared path: build the default connector lazily, then reuse it.
        if _default_connector is None:
            _default_connector = SnowflakeConnector()
        return _default_connector

    # Explicit config: always a fresh connector, never cached.
    return SnowflakeConnector(config)
def reset_connector() -> None:
    """
    Reset the default connector (closes connection and clears singleton).

    Does nothing when no default connector exists. The singleton is
    cleared even if ``close()`` raises, so the next ``get_connector()``
    call builds a fresh connector instead of returning one whose
    shutdown failed. An exception from ``close()`` still propagates to
    the caller.
    """
    global _default_connector
    if _default_connector is not None:
        try:
            _default_connector.close()
        finally:
            # Always drop the reference, whether or not close() succeeded.
            _default_connector = None
def is_snowflake_available() -> bool:
    """
    Report whether snowflake-connector-python is installed.

    Simply exposes the module-level ``SNOWFLAKE_AVAILABLE`` flag.
    """
    return SNOWFLAKE_AVAILABLE
def is_snowflake_configured() -> bool:
    """
    Return True if Snowflake account is configured.

    This is a best-effort probe that never raises: the answer comes from
    the ``is_configured`` attribute of ``get_snowflake_config()``. If
    loading or reading the configuration fails for any reason, the
    failure is logged at DEBUG level (with traceback) and False is
    returned.
    """
    try:
        config = get_snowflake_config()
        return config.is_configured
    except Exception:
        # Deliberately broad: any failure means "not configured". Log it so
        # a broken config file is distinguishable from a missing one.
        logger.debug(
            "Could not determine Snowflake configuration state",
            exc_info=True,
        )
        return False
# Export public API: these are the names bound by a star-import of this
# module; anything not listed here is treated as internal.
__all__ = [
"SnowflakeConnector",
"SnowflakeConnectionError",
"SnowflakeNotConfiguredError",
"SnowflakeNotAvailableError",
"ConnectionInfo",
# Module-level helper functions
"get_connector",
"reset_connector",
"is_snowflake_available",
"is_snowflake_configured",
# Flag returned by is_snowflake_available()
"SNOWFLAKE_AVAILABLE",
]