Restructured src into a more logical hierarchy

This commit is contained in:
2026-02-09 16:22:05 +00:00
parent 7e63e6ea45
commit fcbde7c689
35 changed files with 0 additions and 0 deletions
+42
View File
@@ -0,0 +1,42 @@
# data_processing Package
Data layer for NHS High-Cost Drug Patient Pathway Analysis Tool.
## Core Responsibilities
**Data Pipeline:** `Snowflake → Transforms → Pathway Generation → SQLite`
## Key Modules
**transforms.py** — Core data transformations (moved from tools/data.py):
- `patient_id()` — Creates UPID = Provider Code (first 3 chars) + PersonKey
- `drug_names()` — Standardizes drug names via drugnames.csv lookup
- `department_identification()` — 5-level fallback chain for directory assignment
**pathway_pipeline.py** — Pipeline orchestration:
- Processes 6 date filter combinations × 2 chart types (directory + indication)
- `fetch_and_transform_data()` — Snowflake fetch + UPID/drug/directory transforms
- `process_pathway_for_date_filter()` — Directory charts using `generate_icicle_chart()`
- `process_indication_pathway_for_date_filter()` — Indication charts using `generate_icicle_chart_indication()`
- `insert_pathway_records()` — SQLite insertion with parameterized queries
**diagnosis_lookup.py** — GP diagnosis matching:
- `get_patient_indication_groups()` — Batch queries Snowflake (500 patients at a time)
- Embeds ~148 Search_Term → Cluster_ID mappings as SQL CTE
- Returns most recent match per patient via `QUALIFY ROW_NUMBER()`
**database.py** — SQLite connection pooling and transaction management
**schema.py** — SQL schema definitions (reference tables + pathway_nodes)
**snowflake_connector.py** — Snowflake SSO integration via externalbrowser authenticator
**cache.py** — Query result caching with TTL-based invalidation
## Import Pattern
All imports use package names directly:
```python
from data_processing.transforms import patient_id, drug_names, department_identification
from data_processing.pathway_pipeline import process_all_date_filters
```
+208
View File
@@ -0,0 +1,208 @@
"""
Data processing module for NHS High-Cost Drug Patient Pathway Analysis Tool.
Contains SQLite database management, data loaders, and Snowflake integration.
Handles the migration from CSV-based storage to SQLite for improved performance.
Submodules:
database: SQLite connection management and schema definitions
loader: Data loading abstractions (CSV, SQLite, Snowflake)
snowflake_connector: Snowflake integration with SSO authentication
"""
from data_processing.database import (
DatabaseConfig,
DatabaseManager,
default_db_config,
default_db_manager,
)
from data_processing.schema import (
# Reference table schemas
REF_DRUG_NAMES_SCHEMA,
REF_ORGANIZATIONS_SCHEMA,
REF_DIRECTORIES_SCHEMA,
REF_DRUG_DIRECTORY_MAP_SCHEMA,
REF_DRUG_INDICATION_CLUSTERS_SCHEMA,
REFERENCE_TABLES_SCHEMA,
# Combined schema
ALL_TABLES_SCHEMA,
# Reference table functions
create_reference_tables,
drop_reference_tables,
get_reference_table_counts,
verify_reference_tables_exist,
# Combined functions
create_all_tables,
drop_all_tables,
get_all_table_counts,
verify_all_tables_exist,
)
# Reference data migration functions
from data_processing.reference_data import (
MigrationResult,
migrate_drug_names,
get_drug_name_counts,
verify_drug_names_migration,
migrate_organizations,
get_organization_counts,
verify_organizations_migration,
migrate_directories,
get_directory_counts,
verify_directories_migration,
migrate_drug_directory_map,
get_drug_directory_map_counts,
verify_drug_directory_map_migration,
migrate_drug_indication_clusters,
get_drug_indication_cluster_counts,
verify_drug_indication_clusters_migration,
)
# Data loader abstractions
from data_processing.loader import (
DataLoader,
FileDataLoader,
LoadResult,
get_loader,
REQUIRED_COLUMNS,
OPTIONAL_COLUMNS,
)
# Snowflake connector
from data_processing.snowflake_connector import (
SnowflakeConnector,
SnowflakeConnectionError,
SnowflakeNotConfiguredError,
SnowflakeNotAvailableError,
ConnectionInfo,
get_connector,
reset_connector,
is_snowflake_available,
is_snowflake_configured,
SNOWFLAKE_AVAILABLE,
)
# Query result caching
from data_processing.cache import (
QueryCache,
CacheEntry,
CacheStats,
get_cache,
reset_cache,
is_cache_enabled,
)
# Data source management with fallback chain
from data_processing.data_source import (
DataSourceType,
DataSourceResult,
SourceStatus,
DataSourceManager,
get_data_source_manager,
get_data,
reset_data_source_manager,
)
# Diagnosis lookup (GP diagnosis validation)
from data_processing.diagnosis_lookup import (
ClusterSnomedCodes,
IndicationValidationResult,
DrugIndicationMatchRate,
get_drug_clusters,
get_drug_cluster_ids,
get_cluster_snomed_codes,
patient_has_indication,
validate_indication,
get_indication_match_rate,
batch_validate_indications,
get_available_clusters,
)
__all__ = [
# Database management
"DatabaseConfig",
"DatabaseManager",
"default_db_config",
"default_db_manager",
# Reference table schemas
"REF_DRUG_NAMES_SCHEMA",
"REF_ORGANIZATIONS_SCHEMA",
"REF_DIRECTORIES_SCHEMA",
"REF_DRUG_DIRECTORY_MAP_SCHEMA",
"REF_DRUG_INDICATION_CLUSTERS_SCHEMA",
"REFERENCE_TABLES_SCHEMA",
# Combined schema
"ALL_TABLES_SCHEMA",
# Reference table functions
"create_reference_tables",
"drop_reference_tables",
"get_reference_table_counts",
"verify_reference_tables_exist",
# Combined functions
"create_all_tables",
"drop_all_tables",
"get_all_table_counts",
"verify_all_tables_exist",
# Reference data migration
"MigrationResult",
"migrate_drug_names",
"get_drug_name_counts",
"verify_drug_names_migration",
"migrate_organizations",
"get_organization_counts",
"verify_organizations_migration",
"migrate_directories",
"get_directory_counts",
"verify_directories_migration",
"migrate_drug_directory_map",
"get_drug_directory_map_counts",
"verify_drug_directory_map_migration",
"migrate_drug_indication_clusters",
"get_drug_indication_cluster_counts",
"verify_drug_indication_clusters_migration",
# Data loader abstractions
"DataLoader",
"FileDataLoader",
"LoadResult",
"get_loader",
"REQUIRED_COLUMNS",
"OPTIONAL_COLUMNS",
# Snowflake connector
"SnowflakeConnector",
"SnowflakeConnectionError",
"SnowflakeNotConfiguredError",
"SnowflakeNotAvailableError",
"ConnectionInfo",
"get_connector",
"reset_connector",
"is_snowflake_available",
"is_snowflake_configured",
"SNOWFLAKE_AVAILABLE",
# Query result caching
"QueryCache",
"CacheEntry",
"CacheStats",
"get_cache",
"reset_cache",
"is_cache_enabled",
# Data source management with fallback chain
"DataSourceType",
"DataSourceResult",
"SourceStatus",
"DataSourceManager",
"get_data_source_manager",
"get_data",
"reset_data_source_manager",
# Diagnosis lookup
"ClusterSnomedCodes",
"IndicationValidationResult",
"DrugIndicationMatchRate",
"get_drug_clusters",
"get_drug_cluster_ids",
"get_cluster_snomed_codes",
"patient_has_indication",
"validate_indication",
"get_indication_match_rate",
"batch_validate_indications",
"get_available_clusters",
]
+553
View File
@@ -0,0 +1,553 @@
"""
Query result caching module for NHS Patient Pathway Analysis.
Provides file-based caching for Snowflake query results with TTL-based invalidation.
Supports different TTLs for historical data vs data including the current date.
Cache keys are generated from query hashes. Results are stored as compressed JSON.
Usage:
from data_processing.cache import QueryCache, get_cache
cache = get_cache()
# Check for cached result
result = cache.get(query, params)
if result is None:
# Execute query and cache result
result = execute_query(query, params)
cache.set(query, params, result, includes_current_data=False)
"""
from dataclasses import dataclass
from datetime import datetime, date
from pathlib import Path
from typing import Any, Optional
import gzip
import hashlib
import json
import os
import time
from config import get_snowflake_config, CacheConfig
from core.logging_config import get_logger
logger = get_logger(__name__)
@dataclass
class CacheEntry:
    """Metadata for a cached query result.

    One entry corresponds to a pair of files in the cache directory: the
    gzipped JSON payload and a small JSON metadata sidecar.
    """
    # Key derived from the normalized query and its parameters; it is also the
    # stem of the entry's file names.
    cache_key: str
    # Short SHA-256 prefix of the query text alone (parameters excluded).
    query_hash: str
    # Local time at which the entry was written.
    created_at: datetime
    # Local time after which the entry is treated as stale.
    expires_at: datetime
    # True when the result covers the current date and therefore received the
    # shorter "current data" TTL.
    includes_current_data: bool
    # Number of rows in the cached result.
    row_count: int
    # Size of the compressed payload file in bytes.
    file_size_bytes: int
    # Location of the compressed payload file.
    file_path: Path
@dataclass
class CacheStats:
    """Statistics about the cache.

    Combines on-disk facts (entry count, size, entry ages) with the
    in-memory hit/miss counters of the QueryCache that produced it.
    """
    # Whether caching is switched on in the configuration.
    enabled: bool
    # Directory in which cache files are stored.
    cache_dir: Path
    # Number of entries whose metadata could be read.
    total_entries: int
    # Combined size of the files in the cache directory, in MB.
    total_size_mb: float
    # Configured size limit, in MB.
    max_size_mb: int
    # Creation time of the oldest readable entry, or None when there are none.
    oldest_entry: Optional[datetime]
    # Creation time of the newest readable entry, or None when there are none.
    newest_entry: Optional[datetime]
    # Lookups served from the cache since the counters were last reset.
    hit_count: int
    # Lookups that found nothing usable since the counters were last reset.
    miss_count: int
class QueryCache:
"""
File-based cache for Snowflake query results.
Results are stored as gzipped JSON files with TTL-based expiration.
Supports different TTLs for historical vs current data.
Attributes:
config: CacheConfig with cache settings
cache_dir: Path to cache directory
"""
def __init__(self, config: Optional[CacheConfig] = None, base_path: Optional[Path] = None):
    """
    Set up the cache: resolve its directory and zero the hit/miss counters.

    Args:
        config: Cache settings. When omitted, the ``cache`` section of the
            Snowflake configuration is used.
        base_path: Directory against which a relative cache directory is
            resolved. Defaults to the current working directory.
    """
    if config is None:
        config = get_snowflake_config().cache
    self._config = config
    self._base_path = base_path or Path.cwd()

    # A relative cache directory is anchored at the base path.
    configured_dir = Path(config.directory)
    if configured_dir.is_absolute():
        self._cache_dir = configured_dir
    else:
        self._cache_dir = self._base_path / configured_dir

    # Counters live in memory only and start from zero for every instance.
    self._hit_count = 0
    self._miss_count = 0

    # Only touch the filesystem when caching is actually switched on.
    if self._config.enabled:
        self._cache_dir.mkdir(parents=True, exist_ok=True)
@property
def config(self) -> CacheConfig:
    """Return the cache configuration this instance was initialised with."""
    return self._config

@property
def cache_dir(self) -> Path:
    """Return the resolved directory in which cache files are stored."""
    return self._cache_dir

@property
def is_enabled(self) -> bool:
    """Return True if caching is enabled in the configuration."""
    return self._config.enabled
def _generate_cache_key(self, query: str, params: Optional[tuple] = None) -> str:
"""
Generate a cache key from query and parameters.
Uses SHA256 hash of query + params to create unique key.
"""
# Normalize query (strip whitespace, lowercase)
normalized_query = " ".join(query.lower().split())
# Combine query and params
key_content = normalized_query
if params:
key_content += "|" + "|".join(str(p) for p in params)
# Hash to create key
hash_obj = hashlib.sha256(key_content.encode("utf-8"))
return hash_obj.hexdigest()[:32] # Use first 32 chars for readability
def _get_cache_file_path(self, cache_key: str) -> Path:
"""Get the file path for a cache entry."""
return self._cache_dir / f"{cache_key}.json.gz"
def _get_meta_file_path(self, cache_key: str) -> Path:
"""Get the metadata file path for a cache entry."""
return self._cache_dir / f"{cache_key}.meta.json"
def _is_expired(self, meta: dict) -> bool:
"""Check if a cache entry is expired based on its metadata."""
expires_at = datetime.fromisoformat(meta["expires_at"])
return datetime.now() > expires_at
def get(
    self,
    query: str,
    params: Optional[tuple] = None,
    check_expiry: bool = True
) -> Optional[list[dict]]:
    """
    Get a cached query result.

    Any entry that cannot be read back (unparseable metadata, malformed
    expiry timestamp, truncated or corrupt gzip payload, I/O error) is
    treated as a miss and removed from disk so it does not keep failing.

    Args:
        query: SQL query string
        params: Optional query parameters
        check_expiry: If True, returns None for expired entries

    Returns:
        Cached result as list of dicts, or None if not cached/expired
    """
    # Imported here only to name the error raised for corrupt gzip streams.
    import zlib

    if not self.is_enabled:
        self._miss_count += 1
        return None

    cache_key = self._generate_cache_key(query, params)
    cache_file = self._get_cache_file_path(cache_key)
    meta_file = self._get_meta_file_path(cache_key)

    # An entry is only usable when both the payload and its metadata exist.
    if not cache_file.exists() or not meta_file.exists():
        self._miss_count += 1
        logger.debug(f"Cache miss (not found): {cache_key}")
        return None

    try:
        with open(meta_file, "r", encoding="utf-8") as f:
            meta = json.load(f)

        if check_expiry and self._is_expired(meta):
            self._miss_count += 1
            logger.debug(f"Cache miss (expired): {cache_key}")
            return None

        with gzip.open(cache_file, "rt", encoding="utf-8") as f:
            data = json.load(f)

        # Read before counting the hit so incomplete metadata is recorded
        # as a miss only, not as a hit followed by a miss.
        row_count = meta["row_count"]
    except (ValueError, KeyError, TypeError, EOFError, OSError, zlib.error) as e:
        # ValueError covers JSON/Unicode decode errors and bad ISO dates;
        # EOFError and zlib.error cover truncated or corrupt gzip payloads.
        logger.warning(f"Cache read error for {cache_key}: {e}")
        self._miss_count += 1
        # Clean up corrupted entry
        self._delete_entry(cache_key)
        return None

    self._hit_count += 1
    logger.info(f"Cache hit: {cache_key} ({row_count} rows)")
    return data
def set(
    self,
    query: str,
    params: Optional[tuple],
    data: list[dict],
    includes_current_data: bool = False,
    custom_ttl_seconds: Optional[int] = None
) -> Optional[CacheEntry]:
    """
    Cache a query result.

    Payload and metadata are first written to temporary files and then
    moved into place, so a failed or interrupted write never leaves a
    half-written entry under the real file names.

    Args:
        query: SQL query string
        params: Optional query parameters
        data: Query result as list of dicts
        includes_current_data: If True, uses shorter TTL for current data
        custom_ttl_seconds: Optional custom TTL (overrides config)

    Returns:
        CacheEntry with metadata, or None if caching disabled/failed
    """
    if not self.is_enabled:
        return None

    cache_key = self._generate_cache_key(query, params)
    cache_file = self._get_cache_file_path(cache_key)
    meta_file = self._get_meta_file_path(cache_key)

    # Determine TTL: explicit override first, then the data-dependent default.
    if custom_ttl_seconds is not None:
        ttl = custom_ttl_seconds
    elif includes_current_data:
        ttl = self._config.ttl_current_data_seconds
    else:
        ttl = self._config.ttl_seconds

    now = datetime.now()
    expires_at = datetime.fromtimestamp(now.timestamp() + ttl)

    # Per-process temp names avoid clashes between concurrent writers. The
    # ".tmp" ending keeps them out of every "*.meta.json" scan.
    tmp_suffix = f".{os.getpid()}.tmp"
    tmp_cache_file = cache_file.with_name(cache_file.name + tmp_suffix)
    tmp_meta_file = meta_file.with_name(meta_file.name + tmp_suffix)

    try:
        row_count = len(data)

        # Write compressed data
        with gzip.open(tmp_cache_file, "wt", encoding="utf-8", compresslevel=6) as f:
            json.dump(data, f, default=str)
        file_size = tmp_cache_file.stat().st_size

        # Write metadata
        meta = {
            "cache_key": cache_key,
            "query_hash": hashlib.sha256(query.encode()).hexdigest()[:16],
            "created_at": now.isoformat(),
            "expires_at": expires_at.isoformat(),
            "includes_current_data": includes_current_data,
            "row_count": row_count,
            "file_size_bytes": file_size,
            "ttl_seconds": ttl,
        }
        with open(tmp_meta_file, "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2)

        # Publish the payload first; the metadata only appears once the
        # payload it describes is in place.
        tmp_cache_file.replace(cache_file)
        tmp_meta_file.replace(meta_file)
    except (OSError, TypeError, ValueError) as e:
        # ValueError: json.dump rejects e.g. circular references.
        logger.error(f"Failed to cache result: {e}")
        # Best-effort removal of whatever temp files were created.
        for leftover in (tmp_cache_file, tmp_meta_file):
            try:
                leftover.unlink()
            except FileNotFoundError:
                pass
            except OSError as cleanup_error:
                logger.debug(f"Could not remove temp file {leftover}: {cleanup_error}")
        return None

    logger.info(f"Cached {row_count} rows as {cache_key} (expires in {ttl}s)")

    # The entry is already stored; a problem while trimming the cache must
    # not be reported as a failed write.
    try:
        self._enforce_size_limit()
    except (OSError, TypeError, ValueError) as e:
        logger.warning(f"Cache size enforcement failed: {e}")

    return CacheEntry(
        cache_key=cache_key,
        query_hash=str(meta["query_hash"]),
        created_at=now,
        expires_at=expires_at,
        includes_current_data=includes_current_data,
        row_count=row_count,
        file_size_bytes=file_size,
        file_path=cache_file,
    )
def invalidate(self, query: str, params: Optional[tuple] = None) -> bool:
    """
    Drop the cached result for one query/parameter combination.

    Args:
        query: SQL query string
        params: Optional query parameters

    Returns:
        True if entry was deleted, False if not found
    """
    key_to_remove = self._generate_cache_key(query, params)
    was_deleted = self._delete_entry(key_to_remove)
    return was_deleted
def _delete_entry(self, cache_key: str) -> bool:
"""Delete a cache entry by key."""
cache_file = self._get_cache_file_path(cache_key)
meta_file = self._get_meta_file_path(cache_key)
deleted = False
if cache_file.exists():
cache_file.unlink()
deleted = True
if meta_file.exists():
meta_file.unlink()
deleted = True
if deleted:
logger.debug(f"Deleted cache entry: {cache_key}")
return deleted
def clear(self) -> int:
    """
    Clear all cache entries.

    Deletes every file in the cache directory whose name matches
    ``*.json*`` and resets the hit/miss counters.

    Returns:
        Number of entries deleted. An entry is counted once regardless of
        whether one or both of its files (payload, metadata) were present.
    """
    if not self._cache_dir.exists():
        return 0

    file_count = 0
    cleared_keys = set()
    for file in self._cache_dir.glob("*.json*"):
        try:
            file.unlink()
        except OSError as e:
            logger.warning(f"Failed to delete {file}: {e}")
            continue
        file_count += 1
        # Everything before the first dot is the entry's key, so payload and
        # metadata of the same entry collapse into one counted entry.
        cleared_keys.add(file.name.split(".", 1)[0])

    # Reset stats
    self._hit_count = 0
    self._miss_count = 0

    logger.info(f"Cleared {file_count} cache files")
    return len(cleared_keys)
def clear_expired(self) -> int:
    """
    Remove expired cache entries.

    Entries whose metadata cannot be interpreted (unreadable file, invalid
    JSON, missing or malformed ``expires_at``) are removed as well, since
    their expiry can never be evaluated.

    Returns:
        Number of expired or unusable entries deleted
    """
    if not self._cache_dir.exists():
        return 0

    count = 0
    for meta_file in self._cache_dir.glob("*.meta.json"):
        cache_key = meta_file.stem.replace(".meta", "")
        try:
            with open(meta_file, "r", encoding="utf-8") as f:
                meta = json.load(f)
            expired = self._is_expired(meta)
        except (OSError, ValueError, KeyError, TypeError):
            # ValueError covers invalid JSON and malformed ISO timestamps;
            # KeyError/TypeError cover metadata without a usable expiry.
            expired = True

        if expired:
            self._delete_entry(cache_key)
            count += 1

    logger.info(f"Cleared {count} expired cache entries")
    return count
def _get_total_size_mb(self) -> float:
"""Calculate total cache size in MB."""
if not self._cache_dir.exists():
return 0.0
total_bytes = sum(
f.stat().st_size
for f in self._cache_dir.glob("*")
if f.is_file()
)
return total_bytes / (1024 * 1024)
def _enforce_size_limit(self) -> int:
"""
Enforce cache size limit by removing oldest entries.
Returns:
Number of entries removed
"""
max_size_mb = self._config.max_size_mb
current_size_mb = self._get_total_size_mb()
if current_size_mb <= max_size_mb:
return 0
# Get all entries sorted by creation time
entries = []
for meta_file in self._cache_dir.glob("*.meta.json"):
try:
with open(meta_file, "r", encoding="utf-8") as f:
meta = json.load(f)
entries.append((
meta_file.stem.replace(".meta", ""),
datetime.fromisoformat(meta["created_at"]),
meta.get("file_size_bytes", 0)
))
except (OSError, json.JSONDecodeError, KeyError):
# Clean up corrupted entry
cache_key = meta_file.stem.replace(".meta", "")
self._delete_entry(cache_key)
# Sort by creation time (oldest first)
entries.sort(key=lambda x: x[1])
# Remove oldest entries until under limit
removed = 0
size_to_remove_bytes = (current_size_mb - max_size_mb * 0.9) * 1024 * 1024 # Target 90% of limit
removed_bytes = 0
for cache_key, created_at, file_size in entries:
if removed_bytes >= size_to_remove_bytes:
break
self._delete_entry(cache_key)
removed_bytes += file_size
removed += 1
logger.info(f"Removed {removed} cache entries to enforce size limit")
return removed
def get_stats(self) -> CacheStats:
    """Get cache statistics.

    Entry count and oldest/newest timestamps are derived from the metadata
    files that can be read; metadata that is unreadable or carries a
    malformed ``created_at`` is skipped rather than failing the call.

    Returns:
        CacheStats snapshot. When the cache directory does not exist, the
        entry count is 0, the size is 0.0 and both timestamps are None.
    """
    created_times = []
    if self._cache_dir.exists():
        for meta_file in self._cache_dir.glob("*.meta.json"):
            try:
                with open(meta_file, "r", encoding="utf-8") as f:
                    meta = json.load(f)
                created_times.append(datetime.fromisoformat(meta["created_at"]))
            except (OSError, ValueError, KeyError, TypeError):
                # Unusable metadata is simply left out of the statistics.
                continue

    return CacheStats(
        enabled=self.is_enabled,
        cache_dir=self._cache_dir,
        total_entries=len(created_times),
        # Reports 0.0 on its own when the directory is missing.
        total_size_mb=self._get_total_size_mb(),
        max_size_mb=self._config.max_size_mb,
        oldest_entry=min(created_times) if created_times else None,
        newest_entry=max(created_times) if created_times else None,
        hit_count=self._hit_count,
        miss_count=self._miss_count,
    )
def list_entries(self) -> list[CacheEntry]:
    """List all cache entries with metadata.

    Entries whose metadata is unreadable, incomplete or carries malformed
    timestamps are skipped rather than failing the whole listing.

    Returns:
        Entries sorted by creation time, newest first. Empty when the cache
        directory does not exist.
    """
    if not self._cache_dir.exists():
        return []

    entries = []
    for meta_file in self._cache_dir.glob("*.meta.json"):
        try:
            with open(meta_file, "r", encoding="utf-8") as f:
                meta = json.load(f)

            cache_key = meta["cache_key"]
            entries.append(CacheEntry(
                cache_key=cache_key,
                query_hash=meta.get("query_hash", ""),
                created_at=datetime.fromisoformat(meta["created_at"]),
                expires_at=datetime.fromisoformat(meta["expires_at"]),
                includes_current_data=meta.get("includes_current_data", False),
                row_count=meta.get("row_count", 0),
                file_size_bytes=meta.get("file_size_bytes", 0),
                file_path=self._get_cache_file_path(cache_key),
            ))
        except (OSError, ValueError, KeyError, TypeError):
            # ValueError covers invalid JSON and malformed ISO timestamps.
            continue

    # Sort by creation time (newest first)
    entries.sort(key=lambda entry: entry.created_at, reverse=True)
    return entries
# Module-level singleton
_default_cache: Optional[QueryCache] = None


def get_cache(config: Optional[CacheConfig] = None) -> QueryCache:
    """
    Return a QueryCache, lazily creating the shared default on first use.

    Args:
        config: Optional CacheConfig. When given, a fresh cache built from
            it is returned and the shared default is left untouched. When
            None, the shared default cache is returned.

    Returns:
        QueryCache instance
    """
    global _default_cache

    if config is None:
        # Shared instance: build it once, then keep handing it out.
        if _default_cache is None:
            _default_cache = QueryCache()
        return _default_cache

    # Explicit configuration always gets its own, unshared cache.
    return QueryCache(config)
def reset_cache() -> None:
    """Reset the default cache singleton.

    Only the module-level reference is dropped; files on disk are left
    untouched. The next ``get_cache()`` call without a config creates a
    fresh instance.
    """
    global _default_cache
    _default_cache = None
def is_cache_enabled() -> bool:
    """Report whether the Snowflake configuration has caching switched on."""
    snowflake_settings = get_snowflake_config()
    cache_settings = snowflake_settings.cache
    return cache_settings.enabled
# Export public API
__all__ = [
"QueryCache",
"CacheEntry",
"CacheStats",
"get_cache",
"reset_cache",
"is_cache_enabled",
]
+932
View File
@@ -0,0 +1,932 @@
"""
Unified data access layer with fallback chain for NHS Patient Pathway Analysis.
Provides a high-level interface that automatically selects the best available data source:
1. Cache - Returns cached results if valid and not expired
2. Snowflake - Queries Snowflake warehouse if configured and connected
3. Local - Falls back to SQLite database or CSV/Parquet files
The fallback chain handles connection errors, missing configurations, and
unavailable services gracefully, always attempting to provide data from
some source.
Usage:
from data_processing.data_source import DataSourceManager, get_data
# Simple usage with automatic source selection
result = get_data(
start_date=date(2024, 1, 1),
end_date=date(2024, 12, 31),
trusts=["TRUST A", "TRUST B"],
)
# Or with explicit source preference
manager = DataSourceManager()
result = manager.get_data(
start_date=date(2024, 1, 1),
end_date=date(2024, 12, 31),
preferred_source="snowflake",
)
"""
from dataclasses import dataclass, field
from datetime import date, datetime
from enum import Enum
from pathlib import Path
from typing import Optional, Callable
import pandas as pd
from core.logging_config import get_logger
logger = get_logger(__name__)
class DataSourceType(Enum):
    """Enumeration of available data sources."""
    # Previously stored query results.
    CACHE = "cache"
    # Snowflake warehouse.
    SNOWFLAKE = "snowflake"
    # Local SQLite database.
    SQLITE = "sqlite"
    # Local CSV/Parquet file.
    FILE = "file"
@dataclass
class DataSourceResult:
    """Outcome of a data request, together with where the data came from.

    Attributes:
        df: The loaded DataFrame with patient intervention data
        source_type: Which data source was used
        source_detail: Additional details about the source (e.g., file path, query hash)
        row_count: Number of rows returned; filled in from ``df`` when left at 0
        cached: Whether the result came from cache
        from_fallback: Whether a fallback source was used
        load_time_seconds: Time taken to load data
        warnings: Any warnings generated during loading
    """
    df: pd.DataFrame
    source_type: DataSourceType
    source_detail: str = ""
    row_count: int = 0
    cached: bool = False
    from_fallback: bool = False
    load_time_seconds: float = 0.0
    warnings: list[str] = field(default_factory=list)

    def __post_init__(self):
        # Keep an explicitly supplied count; without a frame there is
        # nothing to count either.
        if self.row_count != 0 or self.df is None:
            return
        self.row_count = len(self.df)
@dataclass
class SourceStatus:
    """Status of a data source.

    Attributes:
        source_type: The type of data source
        available: Whether the source is available
        configured: Whether the source is properly configured
        message: Status message explaining the state
        last_checked: When the status was last checked
    """
    source_type: DataSourceType
    # Defaults describe a source that has not been checked yet:
    # not available, not configured, no message, no check time.
    available: bool = False
    configured: bool = False
    message: str = ""
    last_checked: Optional[datetime] = None
class DataSourceManager:
"""
Manages data access with automatic fallback between sources.
The manager attempts to retrieve data from sources in order of preference:
1. Cache (if enabled and has valid cached data)
2. Snowflake (if configured and connected)
3. SQLite (if database exists with data)
4. Local files (CSV/Parquet)
Attributes:
cache_enabled: Whether to use caching
local_file_path: Path to local CSV/Parquet file (optional fallback)
sqlite_db_path: Path to SQLite database (optional)
Example:
manager = DataSourceManager()
# Check what sources are available
status = manager.check_all_sources()
for s in status:
print(f"{s.source_type.value}: {s.message}")
# Get data with automatic fallback
result = manager.get_data(
start_date=date(2024, 1, 1),
end_date=date(2024, 6, 30),
)
print(f"Got {result.row_count} rows from {result.source_type.value}")
"""
def __init__(
self,
cache_enabled: bool = True,
local_file_path: Optional[Path | str] = None,
sqlite_db_path: Optional[Path | str] = None,
):
"""
Initialize the data source manager.
Args:
cache_enabled: Whether to check cache before querying (default True)
local_file_path: Path to local CSV/Parquet file for file fallback
sqlite_db_path: Path to SQLite database (uses default if None)
"""
self._cache_enabled = cache_enabled
self._local_file_path = Path(local_file_path) if local_file_path else None
self._sqlite_db_path = Path(sqlite_db_path) if sqlite_db_path else None
self._source_status: dict[DataSourceType, SourceStatus] = {}
@property
def cache_enabled(self) -> bool:
    """Return whether caching is enabled for this manager."""
    return self._cache_enabled

@cache_enabled.setter
def cache_enabled(self, value: bool):
    """Set whether caching is enabled for this manager."""
    self._cache_enabled = value
def _check_cache_status(self) -> SourceStatus:
    """Report whether the query cache can be used.

    Returns:
        SourceStatus for the cache: ready (with entry count and size),
        disabled in configuration, or failed with the error text.
    """
    def _report(usable: bool, text: str) -> SourceStatus:
        # For the cache, "available" and "configured" always move together.
        return SourceStatus(
            source_type=DataSourceType.CACHE,
            available=usable,
            configured=usable,
            message=text,
            last_checked=datetime.now(),
        )

    try:
        from data_processing.cache import is_cache_enabled, get_cache

        if not is_cache_enabled():
            return _report(False, "Cache disabled in configuration")

        stats = get_cache().get_stats()
        return _report(
            True,
            f"Cache enabled ({stats.total_entries} entries, {stats.total_size_mb:.1f}MB)",
        )
    except Exception as e:
        return _report(False, f"Cache error: {e}")
def _check_snowflake_status(self) -> SourceStatus:
    """Report whether Snowflake can be used.

    Returns:
        SourceStatus for Snowflake: connector package missing, account not
        configured, ready, or failed with the error text.
    """
    def _report(available: bool, configured: bool, text: str) -> SourceStatus:
        return SourceStatus(
            source_type=DataSourceType.SNOWFLAKE,
            available=available,
            configured=configured,
            message=text,
            last_checked=datetime.now(),
        )

    try:
        from data_processing.snowflake_connector import (
            is_snowflake_available,
            is_snowflake_configured,
        )

        if not is_snowflake_available():
            return _report(False, False, "snowflake-connector-python not installed")

        if not is_snowflake_configured():
            # The package is present, only the account settings are missing.
            return _report(
                True,
                False,
                "Snowflake account not configured in config/snowflake.toml",
            )

        return _report(True, True, "Snowflake configured and ready")
    except Exception as e:
        return _report(False, False, f"Snowflake error: {e}")
def _check_sqlite_status(self) -> SourceStatus:
    """Report whether the SQLite database holds usable pathway data.

    The database counts as available only when the file exists and its
    ``pathway_nodes`` table exists and is non-empty.

    Returns:
        SourceStatus for SQLite describing which of those checks failed,
        or that the database is ready, or the error that occurred.
    """
    def _report(available: bool, configured: bool, text: str) -> SourceStatus:
        return SourceStatus(
            source_type=DataSourceType.SQLITE,
            available=available,
            configured=configured,
            message=text,
            last_checked=datetime.now(),
        )

    try:
        from data_processing.database import default_db_config

        # An explicitly configured path wins over the package default.
        db_path = self._sqlite_db_path or Path(default_db_config.db_path)
        if not db_path.exists():
            return _report(False, True, f"Database not found: {db_path}")

        from data_processing.database import DatabaseManager, DatabaseConfig

        db_manager = DatabaseManager(DatabaseConfig(db_path=db_path))

        if not db_manager.table_exists("pathway_nodes"):
            return _report(False, True, "pathway_nodes table not found")

        node_count = db_manager.get_table_count("pathway_nodes")
        if node_count == 0:
            return _report(False, True, "pathway_nodes table is empty")

        return _report(
            True,
            True,
            f"SQLite database ready ({node_count:,} pathway nodes)",
        )
    except Exception as e:
        return _report(False, False, f"SQLite error: {e}")
def _check_file_status(self) -> SourceStatus:
    """Report whether the configured local data file can be used.

    Returns:
        SourceStatus for the file source: no path configured, file missing,
        or ready (with file name and size in MB).
    """
    def _report(available: bool, configured: bool, text: str) -> SourceStatus:
        return SourceStatus(
            source_type=DataSourceType.FILE,
            available=available,
            configured=configured,
            message=text,
            last_checked=datetime.now(),
        )

    local_file = self._local_file_path
    if local_file is None:
        return _report(False, False, "No local file path configured")

    if not local_file.exists():
        return _report(False, True, f"File not found: {self._local_file_path}")

    size_mb = local_file.stat().st_size / (1024 * 1024)
    return _report(
        True,
        True,
        f"Local file ready: {self._local_file_path.name} ({size_mb:.1f}MB)",
    )
def check_source_status(self, source_type: DataSourceType) -> SourceStatus:
    """
    Run the availability check that belongs to one data source.

    Args:
        source_type: The type of source to check

    Returns:
        SourceStatus with current availability information; an
        "Unknown source type" status when no check matches
    """
    # Pairs are tried in order and compared with ==, exactly like an
    # if/elif chain would.
    checks = (
        (DataSourceType.CACHE, self._check_cache_status),
        (DataSourceType.SNOWFLAKE, self._check_snowflake_status),
        (DataSourceType.SQLITE, self._check_sqlite_status),
        (DataSourceType.FILE, self._check_file_status),
    )
    for known_type, run_check in checks:
        if source_type == known_type:
            return run_check()

    return SourceStatus(
        source_type=source_type,
        available=False,
        configured=False,
        message=f"Unknown source type: {source_type}",
        last_checked=datetime.now(),
    )
def check_all_sources(self) -> list[SourceStatus]:
    """
    Check every data source and remember the outcome on the manager.

    Returns:
        List of SourceStatus for each source type, in enum order
    """
    collected = []
    for kind in DataSourceType:
        # Record each result as soon as it is known, then report it.
        self._source_status[kind] = current = self.check_source_status(kind)
        collected.append(current)
    return collected
def _build_cache_key_params(
self,
start_date: Optional[date],
end_date: Optional[date],
trusts: Optional[list[str]],
drugs: Optional[list[str]],
directories: Optional[list[str]],
) -> tuple[str, tuple]:
"""Build a cache-compatible query string and params for the filter criteria."""
# Create a canonical representation for caching
query_parts = ["SELECT * FROM activity_data"]
params = []
conditions = []
if start_date:
conditions.append("start_date >= ?")
params.append(str(start_date))
if end_date:
conditions.append("end_date <= ?")
params.append(str(end_date))
if trusts:
placeholders = ",".join(["?"] * len(trusts))
conditions.append(f"trust IN ({placeholders})")
params.extend(sorted(trusts))
if drugs:
placeholders = ",".join(["?"] * len(drugs))
conditions.append(f"drug IN ({placeholders})")
params.extend(sorted(drugs))
if directories:
placeholders = ",".join(["?"] * len(directories))
conditions.append(f"directory IN ({placeholders})")
params.extend(sorted(directories))
if conditions:
query_parts.append("WHERE " + " AND ".join(conditions))
query = " ".join(query_parts)
return query, tuple(params)
def _try_cache(
    self,
    start_date: Optional[date],
    end_date: Optional[date],
    trusts: Optional[list[str]],
    drugs: Optional[list[str]],
    directories: Optional[list[str]],
) -> Optional[DataSourceResult]:
    """Look up the filter combination in the query cache.

    Returns:
        DataSourceResult built from the cached rows, or None when caching
        is off, nothing is cached for these filters, or the lookup fails
        (failures are logged as warnings).
    """
    if not self._cache_enabled:
        return None

    try:
        from data_processing.cache import get_cache

        query_cache = get_cache()
        if not query_cache.is_enabled:
            return None

        key_query, key_params = self._build_cache_key_params(
            start_date, end_date, trusts, drugs, directories
        )
        records = query_cache.get(key_query, key_params)
        if records is None:
            logger.debug("Cache miss")
            return None

        # Rebuild the DataFrame from the cached rows and re-parse the
        # 'Intervention Date' column into datetimes when it is present.
        frame = pd.DataFrame(records)
        date_column = 'Intervention Date'
        if date_column in frame.columns:
            frame[date_column] = pd.to_datetime(frame[date_column])

        logger.info(f"Cache hit: {len(frame)} rows")
        return DataSourceResult(
            df=frame,
            source_type=DataSourceType.CACHE,
            source_detail=f"cache_key={key_query[:50]}...",
            row_count=len(frame),
            cached=True,
            from_fallback=False,
        )
    except Exception as e:
        logger.warning(f"Cache lookup failed: {e}")
        return None
def _try_snowflake(
    self,
    start_date: Optional[date],
    end_date: Optional[date],
    trusts: Optional[list[str]],
    drugs: Optional[list[str]],
    directories: Optional[list[str]],
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> Optional[DataSourceResult]:
    """Fetch activity data from Snowflake and bring it into the expected shape.

    The date range is applied by the Snowflake fetch; trust, drug and
    directory filters are applied locally after the transforms have run.
    ``progress_callback`` is accepted but not used by this method.

    Returns:
        DataSourceResult with the transformed and filtered data, or None
        when Snowflake is unavailable, not configured, returns no rows, or
        any step fails (failures are logged as warnings).
    """
    import time

    try:
        from data_processing.snowflake_connector import (
            is_snowflake_available,
            is_snowflake_configured,
            get_connector,
            SnowflakeConnectionError,
        )

        if not is_snowflake_available():
            logger.debug("Snowflake connector not installed")
            return None
        if not is_snowflake_configured():
            logger.debug("Snowflake not configured")
            return None

        connector = get_connector()
        logger.info("Fetching data from Snowflake...")
        fetch_started = time.time()

        # Trust names are not translated to provider codes yet, so the
        # provider filter is left open and trusts are filtered locally below.
        rows = connector.fetch_activity_data(
            start_date=start_date,
            end_date=end_date,
            provider_codes=None,  # TODO: map trust names to provider codes if needed
        )
        if not rows:
            logger.warning("Snowflake returned no data")
            return None

        df = pd.DataFrame(rows)
        # The timing covers fetching and DataFrame construction only.
        load_time = time.time() - fetch_started
        logger.info(f"Snowflake loaded {len(df)} rows in {load_time:.2f}s")

        # Apply local transformations to match expected format
        # (patient_id, drug_names, department_identification)
        from data_processing.transforms import patient_id, drug_names, department_identification
        from core import default_paths

        df = patient_id(df)
        df = drug_names(df, paths=default_paths)
        df = department_identification(df, paths=default_paths)

        # A filter is applied only when values were requested and the
        # column it targets is present.
        local_filters = (
            (trusts, 'OrganisationName'),
            (drugs, 'Drug Name'),
            (directories, 'Directory'),
        )
        for wanted, column in local_filters:
            if wanted and column in df.columns:
                df = df[df[column].isin(wanted)]

        return DataSourceResult(
            df=df,
            source_type=DataSourceType.SNOWFLAKE,
            source_detail="DATA_HUB.CDM.Acute__Conmon__PatientLevelDrugs",
            row_count=len(df),
            cached=False,
            from_fallback=False,
            load_time_seconds=load_time,
        )
    except Exception as e:
        logger.warning(f"Snowflake query failed: {e}")
        return None
def _try_sqlite(
    self,
    start_date: Optional[date],
    end_date: Optional[date],
    trusts: Optional[list[str]],
    drugs: Optional[list[str]],
    directories: Optional[list[str]],
) -> Optional[DataSourceResult]:
    """Placeholder for the retired SQLite raw-data source.

    Raw intervention rows are no longer kept in SQLite (the app reads
    pre-computed pathway_nodes via load_pathway_data() instead), so this
    step exists only to keep the fallback-chain interface stable. All
    arguments are ignored.

    Returns:
        Always None.
    """
    logger.debug("SQLite raw data fallback skipped (fact_interventions removed)")
    return None
def _try_file(
    self,
    start_date: Optional[date],
    end_date: Optional[date],
    trusts: Optional[list[str]],
    drugs: Optional[list[str]],
    directories: Optional[list[str]],
) -> Optional[DataSourceResult]:
    """Attempt to load data from the configured local file.

    The whole file is loaded through FileDataLoader and then filtered in
    memory: ``start_date`` is inclusive, ``end_date`` is exclusive, and the
    list filters keep rows whose column value is in the given list. A
    filter is skipped when its column is absent.

    Returns:
        A DataSourceResult tagged as FILE (always marked as a fallback), or
        None when no file is configured, the file fails validation, or
        loading raises.
    """
    import time

    if self._local_file_path is None:
        logger.debug("No local file configured")
        return None

    try:
        from data_processing.loader import FileDataLoader

        file_loader = FileDataLoader(file_path=self._local_file_path)
        usable, reason = file_loader.validate_source()
        if not usable:
            logger.debug(f"Local file not available: {reason}")
            return None

        began = time.time()
        frame = file_loader.load().df

        # Date window first, then the membership filters.
        if start_date and 'Intervention Date' in frame.columns:
            frame = frame[frame['Intervention Date'] >= pd.Timestamp(start_date)]
        if end_date and 'Intervention Date' in frame.columns:
            frame = frame[frame['Intervention Date'] < pd.Timestamp(end_date)]
        for wanted, column in (
            (trusts, 'OrganisationName'),
            (drugs, 'Drug Name'),
            (directories, 'Directory'),
        ):
            if wanted and column in frame.columns:
                frame = frame[frame[column].isin(wanted)]

        elapsed = time.time() - began
        logger.info(f"File loaded and filtered: {len(frame)} rows in {elapsed:.2f}s")
        return DataSourceResult(
            df=frame,
            source_type=DataSourceType.FILE,
            source_detail=str(self._local_file_path),
            row_count=len(frame),
            cached=False,
            from_fallback=True,
            load_time_seconds=elapsed,
        )
    except Exception as exc:
        logger.warning(f"File load failed: {exc}")
        return None
def get_data(
    self,
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    trusts: Optional[list[str]] = None,
    drugs: Optional[list[str]] = None,
    directories: Optional[list[str]] = None,
    preferred_source: Optional[str] = None,
    skip_cache: bool = False,
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> "DataSourceResult":
    """
    Get patient intervention data from the best available source.

    The fallback chain is: Cache → Snowflake → SQLite → File. When
    ``preferred_source`` is given it is tried first (bypassing the cache);
    if it fails, the normal chain runs but does not try that same source a
    second time.

    Args:
        start_date: Optional start date for filtering (inclusive)
        end_date: Optional end date for filtering (exclusive)
        trusts: Optional list of trust names to filter
        drugs: Optional list of drug names to filter
        directories: Optional list of directories to filter
        preferred_source: Optional preferred source ("snowflake", "sqlite",
            "file"), case-insensitive. An unrecognised value is ignored and
            reported in the result's warnings.
        skip_cache: If True, bypass cache and query source directly
        progress_callback: Optional callback(current, total) for progress updates

    Returns:
        DataSourceResult with the loaded data and metadata. Its
        ``load_time_seconds`` is the total time spent in this call, and any
        warnings collected along the way are appended to ``warnings``
        regardless of which source finally served the data.

    Raises:
        ValueError: If no data source is available or all sources fail
    """
    import time

    started = time.time()
    warnings: list[str] = []
    filters = (start_date, end_date, trusts, drugs, directories)

    # One zero-argument callable per source, so the preferred-source attempt
    # and the fallback chain go through exactly the same calls.
    fetchers = {
        "snowflake": lambda: self._try_snowflake(*filters, progress_callback),
        "sqlite": lambda: self._try_sqlite(*filters),
        "file": lambda: self._try_file(*filters),
    }

    def finalize(result, from_fallback=False):
        """Stamp timing/fallback metadata and attach collected warnings."""
        if from_fallback:
            result.from_fallback = True
        result.load_time_seconds = time.time() - started
        if warnings:
            result.warnings.extend(warnings)
        return result

    # Name of a preferred source that has already failed in this call; the
    # fallback chain skips it instead of repeating the same failing attempt
    # (for Snowflake that would mean a second connection/auth round-trip).
    already_failed = None

    if preferred_source:
        preferred = preferred_source.lower()
        fetch_preferred = fetchers.get(preferred)
        if fetch_preferred is None:
            warnings.append(f"Unknown preferred source '{preferred_source}' ignored")
        else:
            result = fetch_preferred()
            if result:
                return finalize(result)
            warnings.append(f"Preferred source '{preferred}' unavailable")
            already_failed = preferred

    # Standard fallback chain: cache → snowflake → sqlite → file

    # 1. Cache (unless skipped)
    if not skip_cache:
        result = self._try_cache(*filters)
        if result:
            return finalize(result)

    # 2. Snowflake; a fresh result is cached for future queries
    if already_failed != "snowflake":
        result = fetchers["snowflake"]()
        if result:
            if self._cache_enabled:
                self._cache_result(
                    result.df,
                    *filters,
                    includes_current_data=end_date is None or end_date >= date.today(),
                )
            return finalize(result)

    # 3./4. SQLite, then local file; both count as fallbacks because
    # Snowflake was not the source that served the data.
    for source_name in ("sqlite", "file"):
        if source_name == already_failed:
            continue
        result = fetchers[source_name]()
        if result:
            return finalize(result, from_fallback=True)

    # All sources failed
    source_status = self.check_all_sources()
    status_msg = "; ".join(
        f"{s.source_type.value}: {s.message}" for s in source_status
    )
    raise ValueError(f"No data source available. Status: {status_msg}")
def _cache_result(
    self,
    df: pd.DataFrame,
    start_date: Optional[date],
    end_date: Optional[date],
    trusts: Optional[list[str]],
    drugs: Optional[list[str]],
    directories: Optional[list[str]],
    includes_current_data: bool = False,
) -> bool:
    """Store a query result in the cache under the key for these filters.

    The DataFrame is copied, datetime columns are converted to strings
    (so the rows can be serialized) and the rows are stored as a list of
    dicts.

    Returns:
        True when the cache accepted the entry; False when the cache is
        disabled, declined the entry, or anything raised.
    """
    try:
        from data_processing.cache import get_cache

        store = get_cache()
        if not store.is_enabled:
            return False

        key_query, key_params = self._build_cache_key_params(
            start_date, end_date, trusts, drugs, directories
        )

        # Work on a copy so the caller's frame keeps its datetime dtypes.
        serializable = df.copy()
        for column in serializable.columns:
            if pd.api.types.is_datetime64_any_dtype(serializable[column]):
                serializable[column] = serializable[column].astype(str)
        records = serializable.to_dict(orient='records')

        entry = store.set(
            key_query, key_params, records,
            includes_current_data=includes_current_data
        )
        if not entry:
            return False

        logger.info(f"Cached {len(records)} rows (key={entry.cache_key[:16]}...)")
        return True
    except Exception as exc:
        # Caching is an optimisation; failure here must not fail the query.
        logger.warning(f"Failed to cache result: {exc}")
        return False
def clear_cache(self) -> int:
    """
    Remove every entry from the query cache.

    Returns:
        Whatever the cache's ``clear()`` reports (the number of entries
        cleared), or 0 if the cache could not be reached or cleared.
    """
    try:
        from data_processing.cache import get_cache

        return get_cache().clear()
    except Exception as exc:
        logger.warning(f"Failed to clear cache: {exc}")
        return 0
def refresh_from_snowflake(
    self,
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    trusts: Optional[list[str]] = None,
    drugs: Optional[list[str]] = None,
    directories: Optional[list[str]] = None,
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> DataSourceResult:
    """
    Query Snowflake directly, ignoring the cache and every other source.

    Unlike get_data(), nothing is tried as a fallback: the call fails
    loudly when Snowflake cannot serve the request. A successful result is
    written to the cache when caching is enabled.

    Args:
        start_date: Optional start date for filtering
        end_date: Optional end date for filtering
        trusts: Optional list of trust names
        drugs: Optional list of drug names
        directories: Optional list of directories
        progress_callback: Optional progress callback

    Returns:
        DataSourceResult from Snowflake

    Raises:
        ValueError: If the connector is not installed, Snowflake is not
            configured, or the query fails
    """
    from data_processing.snowflake_connector import (
        is_snowflake_available,
        is_snowflake_configured,
    )

    # Pre-flight checks give the caller a precise reason instead of the
    # generic "query failed" that _try_snowflake's None would produce.
    if not is_snowflake_available():
        raise ValueError("Snowflake connector not installed")
    if not is_snowflake_configured():
        raise ValueError("Snowflake not configured - edit config/snowflake.toml")

    fresh = self._try_snowflake(
        start_date, end_date, trusts, drugs, directories, progress_callback
    )
    if fresh is None:
        raise ValueError("Snowflake query failed - check logs for details")

    if self._cache_enabled:
        # An open-ended or not-yet-past end date means the result may
        # still change, which the cache is told about.
        still_current = end_date is None or end_date >= date.today()
        self._cache_result(
            fresh.df,
            start_date, end_date, trusts, drugs, directories,
            includes_current_data=still_current
        )
    return fresh
# Module-level singleton and convenience functions
# Shared DataSourceManager used when no custom paths are requested. It stays
# None until get_data_source_manager() creates it on first use, and
# reset_data_source_manager() sets it back to None.
_default_manager: Optional[DataSourceManager] = None
def get_data_source_manager(
    cache_enabled: bool = True,
    local_file_path: Optional[Path | str] = None,
    sqlite_db_path: Optional[Path | str] = None,
) -> DataSourceManager:
    """
    Return a DataSourceManager, reusing the module singleton when possible.

    Passing either path yields a brand-new, unshared manager. Otherwise the
    module-level singleton is returned, created on first use; note that
    ``cache_enabled`` only takes effect at that first creation.

    Args:
        cache_enabled: Whether to enable caching
        local_file_path: Optional path to local CSV/Parquet file
        sqlite_db_path: Optional path to SQLite database

    Returns:
        DataSourceManager instance
    """
    global _default_manager

    custom_paths_requested = bool(local_file_path or sqlite_db_path)
    if custom_paths_requested:
        # Custom locations never touch the shared instance.
        return DataSourceManager(
            cache_enabled=cache_enabled,
            local_file_path=local_file_path,
            sqlite_db_path=sqlite_db_path,
        )

    if _default_manager is None:
        _default_manager = DataSourceManager(cache_enabled=cache_enabled)
    return _default_manager
def get_data(
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    trusts: Optional[list[str]] = None,
    drugs: Optional[list[str]] = None,
    directories: Optional[list[str]] = None,
    preferred_source: Optional[str] = None,
    skip_cache: bool = False,
) -> DataSourceResult:
    """
    Load data through the default (singleton) DataSourceManager.

    Thin wrapper that forwards every argument unchanged to the manager's
    get_data(); no progress callback is passed.

    Args:
        start_date: Optional start date for filtering
        end_date: Optional end date for filtering
        trusts: Optional list of trust names
        drugs: Optional list of drug names
        directories: Optional list of directories
        preferred_source: Optional preferred source
        skip_cache: If True, bypass cache

    Returns:
        DataSourceResult with loaded data
    """
    return get_data_source_manager().get_data(
        start_date=start_date,
        end_date=end_date,
        trusts=trusts,
        drugs=drugs,
        directories=directories,
        preferred_source=preferred_source,
        skip_cache=skip_cache,
    )
def reset_data_source_manager() -> None:
    """Discard the shared manager so the next lookup builds a fresh one."""
    global _default_manager
    _default_manager = None
# Export public API
# Names made available by `from <this module> import *`: the result/status
# types, the manager class, and the module-level convenience functions.
__all__ = [
    "DataSourceType",
    "DataSourceResult",
    "SourceStatus",
    "DataSourceManager",
    "get_data_source_manager",
    "get_data",
    "reset_data_source_manager",
]
+239
View File
@@ -0,0 +1,239 @@
"""
SQLite database connection management for NHS High-Cost Drug Patient Pathway Analysis Tool.
Provides connection management, schema initialization, and common database operations.
Uses context manager pattern for safe resource handling.
"""
import sqlite3
from contextlib import contextmanager
from pathlib import Path
from typing import Optional, Generator, Literal
from core.logging_config import get_logger
logger = get_logger(__name__)
class DatabaseConfig:
    """
    Where the SQLite database lives and how connections to it are opened.

    Attributes:
        db_path: Path to the SQLite database file
        timeout: Connection timeout in seconds (default: 30)
        isolation_level: Transaction isolation level (default: None for autocommit)
    """

    DEFAULT_DB_NAME = "pathways.db"

    def __init__(
        self,
        db_path: Optional[Path] = None,
        data_dir: Optional[Path] = None,
        timeout: float = 30.0,
        isolation_level: Optional[Literal['DEFERRED', 'EXCLUSIVE', 'IMMEDIATE']] = None
    ):
        """
        Build the configuration, resolving the database file location.

        Args:
            db_path: Full path to database file. Takes precedence over
                data_dir when both are given.
            data_dir: Directory to place database in (file name is
                DEFAULT_DB_NAME). Defaults to ./data/
            timeout: Connection timeout in seconds.
            isolation_level: Transaction isolation level. None = autocommit.
        """
        if db_path is not None:
            resolved = Path(db_path)
        else:
            directory = Path(data_dir) if data_dir is not None else Path("./data")
            resolved = directory / self.DEFAULT_DB_NAME

        self.db_path = resolved
        self.timeout = timeout
        self.isolation_level = isolation_level

    def validate(self) -> list[str]:
        """
        Check that the configuration can be used.

        Currently the only check is that the database's parent directory
        exists.

        Returns:
            List of error messages. Empty list means configuration is valid.
        """
        directory = self.db_path.parent
        if directory.exists():
            return []
        return [f"Database directory does not exist: {directory}"]
class DatabaseManager:
    """
    Manages SQLite database connections and operations.

    Provides context managers for safe connection handling and methods
    for common database operations. Every connection is opened with
    foreign-key enforcement on and ``sqlite3.Row`` as the row factory.

    Usage:
        db_manager = DatabaseManager()

        # Using context manager (recommended)
        with db_manager.get_connection() as conn:
            cursor = conn.execute("SELECT * FROM ref_drug_names")
            results = cursor.fetchall()

        # Or get a managed connection for longer operations
        conn = db_manager.connect()
        try:
            # ... do work ...
        finally:
            conn.close()
    """

    def __init__(self, config: Optional["DatabaseConfig"] = None):
        """
        Initialize the database manager.

        Args:
            config: Database configuration. If None, uses default configuration.
        """
        self.config = config or DatabaseConfig()
        self._connection: Optional[sqlite3.Connection] = None

    @property
    def db_path(self) -> Path:
        """Path to the SQLite database file."""
        return self.config.db_path

    @property
    def exists(self) -> bool:
        """Check if the database file exists."""
        return self.db_path.exists()

    def _open_connection(self, isolation_level: Optional[str]) -> sqlite3.Connection:
        """Open a connection with the given isolation level and standard setup.

        The connection is closed again if the setup itself fails, so a
        failing PRAGMA cannot leak an open handle.
        """
        conn = sqlite3.connect(
            str(self.db_path),
            timeout=self.config.timeout,
            isolation_level=isolation_level
        )
        try:
            # Enable foreign key support
            conn.execute("PRAGMA foreign_keys = ON")
            # Return rows as sqlite3.Row for dict-like access
            conn.row_factory = sqlite3.Row
        except Exception:
            conn.close()
            raise
        return conn

    def connect(self) -> sqlite3.Connection:
        """
        Create a new database connection using the configured isolation level.

        Returns:
            sqlite3.Connection: New database connection.

        Note:
            The caller is responsible for closing the connection.
            Consider using get_connection() context manager instead.
        """
        return self._open_connection(self.config.isolation_level)

    @contextmanager
    def get_connection(self) -> Generator[sqlite3.Connection, None, None]:
        """
        Context manager for database connections.

        The connection is rolled back if the body raises and is always
        closed on exit. Nothing is committed automatically; with a
        non-autocommit isolation level the caller must commit.

        Yields:
            sqlite3.Connection: Database connection.

        Example:
            with db_manager.get_connection() as conn:
                conn.execute("INSERT INTO table VALUES (?)", (value,))
                conn.commit()
        """
        conn = self.connect()
        try:
            yield conn
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    @contextmanager
    def get_transaction(self) -> Generator[sqlite3.Connection, None, None]:
        """
        Context manager for transactional operations.

        Automatically commits on success, rolls back on exception. The
        connection always uses DEFERRED isolation, regardless of the
        configured isolation level.

        Yields:
            sqlite3.Connection: Database connection in transaction mode.

        Example:
            with db_manager.get_transaction() as conn:
                conn.execute("INSERT INTO table VALUES (?)", (value1,))
                conn.execute("INSERT INTO other_table VALUES (?)", (value2,))
            # Auto-commits if no exception
        """
        conn = self._open_connection("DEFERRED")  # Explicit transaction mode
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    def execute_script(self, sql_script: str) -> None:
        """
        Execute a SQL script (multiple statements).

        Args:
            sql_script: SQL script containing one or more statements.
        """
        with self.get_connection() as conn:
            conn.executescript(sql_script)
        logger.info("Executed SQL script successfully")

    def table_exists(self, table_name: str) -> bool:
        """
        Check if a table exists in the database.

        Args:
            table_name: Name of the table to check.

        Returns:
            True if the table exists, False otherwise.
        """
        with self.get_connection() as conn:
            cursor = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
                (table_name,)
            )
            return cursor.fetchone() is not None

    def get_table_count(self, table_name: str) -> int:
        """
        Get the row count for a table.

        Args:
            table_name: Name of the table (a bare name, not pre-quoted).

        Returns:
            Number of rows in the table.

        Raises:
            sqlite3.OperationalError: If no table with that name exists.
        """
        # Table names cannot be bound as SQL parameters, so the name is
        # embedded as a double-quoted identifier with embedded quotes
        # doubled. Whatever the caller passes is then treated as one table
        # name and can never contribute extra SQL.
        quoted_name = '"' + table_name.replace('"', '""') + '"'
        with self.get_connection() as conn:
            row = conn.execute(f"SELECT COUNT(*) FROM {quoted_name}").fetchone()
            return row[0] if row else 0
# Default instance for application-wide use
# Both are created at import time with default settings, i.e. the database
# path resolves to ./data/pathways.db relative to the working directory.
default_db_config = DatabaseConfig()
default_db_manager = DatabaseManager(default_db_config)
File diff suppressed because it is too large Load Diff
+246
View File
@@ -0,0 +1,246 @@
"""
Data loader abstractions for NHS High-Cost Drug Patient Pathway Analysis Tool.
Provides a unified interface for loading patient intervention data from:
- CSV/Parquet files (current behavior)
- SQLite database (new, faster approach)
- Snowflake (future, direct from warehouse)
The DataLoader ABC defines the contract for all loader implementations.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
import pandas as pd
from core import PathConfig, default_paths
from core.logging_config import get_logger
logger = get_logger(__name__)
@dataclass
class LoadResult:
    """Outcome of one data load: the frame plus descriptive metadata.

    Attributes:
        df: The loaded DataFrame with processed patient intervention data
        source: Description of the data source (e.g., "file:/path/to/file.csv")
        row_count: Number of rows loaded
        columns: Column names of the DataFrame; filled in from ``df``
            automatically when left empty
        load_time_seconds: Time taken to load the data
    """

    df: pd.DataFrame
    source: str
    row_count: int
    columns: list[str] = field(default_factory=list)
    load_time_seconds: float = 0.0

    def __post_init__(self):
        # An explicitly supplied column list is kept as-is; otherwise it is
        # derived from the frame.
        if self.columns:
            return
        self.columns = list(self.df.columns)
# Expected columns in a processed DataFrame
# These are the columns that generate_graph() expects to receive.
# DataLoader.validate_dataframe() reports any of these that are absent.
REQUIRED_COLUMNS = [
    "UPID", # Unique Patient ID (Provider Code prefix + PersonKey)
    "Drug Name", # Standardized drug name
    "Intervention Date", # Date of intervention
    "Price Actual", # Cost of intervention
    "OrganisationName", # NHS Trust name
    "Directory", # Medical specialty/directory
    "Provider Code", # NHS provider code
    "PersonKey", # Patient identifier within provider
]
# Additional columns that are useful but not strictly required
# (not checked by validate_dataframe()).
OPTIONAL_COLUMNS = [
    "UPIDTreatment", # UPID + Drug Name combo (created by generate_graph)
    "Treatment Function Code", # NHS treatment function code
    "Additional Detail 1",
    "Additional Detail 2",
    "Additional Detail 3",
    "Additional Detail 4",
    "Additional Detail 5",
]
class DataLoader(ABC):
    """Contract that every data loader implements.

    Concrete loaders provide load(), validate_source() and
    source_description. The DataFrame returned by load() is intended for
    generate_graph() and must contain at least REQUIRED_COLUMNS.
    """

    @abstractmethod
    def load(self) -> LoadResult:
        """Load and process patient intervention data.

        Returns:
            LoadResult containing the processed DataFrame and metadata.
            The DataFrame must contain all REQUIRED_COLUMNS.

        Raises:
            FileNotFoundError: If the data source doesn't exist
            ValueError: If the data is malformed or missing required columns
        """

    @abstractmethod
    def validate_source(self) -> tuple[bool, str]:
        """Check if the data source is valid and accessible.

        Returns:
            Tuple of (is_valid, message).
            If is_valid is False, message explains the issue.
        """

    @property
    @abstractmethod
    def source_description(self) -> str:
        """Human-readable description of the data source."""

    def validate_dataframe(self, df: pd.DataFrame) -> tuple[bool, list[str]]:
        """Report which REQUIRED_COLUMNS are absent from a DataFrame.

        Args:
            df: DataFrame to validate

        Returns:
            Tuple of (is_valid, missing_columns). is_valid is True exactly
            when missing_columns is empty; the list keeps the order of
            REQUIRED_COLUMNS.
        """
        available = df.columns
        absent = [name for name in REQUIRED_COLUMNS if name not in available]
        return not absent, absent
class FileDataLoader(DataLoader):
    """Loads data from CSV or Parquet files.

    Mirrors the pipeline of dashboard_gui.main():
    1. Read CSV or Parquet file
    2. Apply patient_id() transformation
    3. Convert dates
    4. Apply drug_names() standardization
    5. Clean organization names
    6. Apply department_identification()

    Args:
        file_path: Path to the CSV or Parquet file
        paths: PathConfig for reference data file locations (uses default_paths if None)
    """

    def __init__(
        self,
        file_path: Path | str,
        paths: Optional[PathConfig] = None,
    ):
        self.file_path = Path(file_path)
        self.paths = paths or default_paths

    def validate_source(self) -> tuple[bool, str]:
        """Check if the file exists and has a supported extension."""
        if not self.file_path.exists():
            return False, f"File not found: {self.file_path}"

        suffix = self.file_path.suffix.lower()
        if suffix in ('.csv', '.parquet'):
            return True, "OK"
        return False, f"Unsupported file type: {suffix}. Must be .csv or .parquet"

    @property
    def source_description(self) -> str:
        return f"file:{self.file_path}"

    def load(self) -> LoadResult:
        """Read the file and run it through the transformation pipeline.

        Raises:
            FileNotFoundError: If validate_source() rejects the file (this
                includes an unsupported extension).
            ValueError: If the processed DataFrame lacks a required column.
        """
        import time
        from data_processing import transforms

        began = time.time()

        # Validate source before loading
        source_ok, problem = self.validate_source()
        if not source_ok:
            raise FileNotFoundError(problem)

        # Read file based on extension (anything that is not .csv is
        # .parquet at this point)
        suffix = self.file_path.suffix.lower()
        logger.info(f"Reading {suffix} file: {self.file_path}")
        if suffix == '.csv':
            raw = pd.read_csv(self.file_path, low_memory=False)
        else:
            raw = pd.read_parquet(self.file_path)
        logger.info(f"File read successfully. {len(raw)} rows.")

        # Apply transformations (same as dashboard_gui.main())
        frame = transforms.patient_id(raw)
        logger.info("Patient ID processing complete.")

        frame['Intervention Date'] = pd.to_datetime(frame['Intervention Date'], format="%Y-%m-%d")
        logger.info("Date conversion complete.")

        # Keep the untouched drug name alongside the standardized one
        # (for SQLite storage)
        frame['Drug Name Raw'] = frame['Drug Name'].copy()
        frame = transforms.drug_names(frame, self.paths)
        logger.info("Drug name processing complete.")

        frame['OrganisationName'] = frame['OrganisationName'].str.replace(',', '')
        logger.info("Organisation name cleaning complete.")

        frame = transforms.department_identification(frame, self.paths)
        logger.info("Department identification complete.")

        # Validate result
        complete, missing = self.validate_dataframe(frame)
        if not complete:
            raise ValueError(f"Processed DataFrame missing required columns: {missing}")

        elapsed = time.time() - began
        logger.info(f"Data loading complete. {len(frame)} rows in {elapsed:.2f}s")
        return LoadResult(
            df=frame,
            source=self.source_description,
            row_count=len(frame),
            load_time_seconds=elapsed,
        )
def get_loader(
    source: str | Path,
    paths: Optional[PathConfig] = None,
    **kwargs
) -> DataLoader:
    """Factory function to create the appropriate DataLoader.

    At present every source is treated as a file path and a FileDataLoader
    is returned.

    Args:
        source: File path (CSV/Parquet)
        paths: PathConfig for reference data (used by FileDataLoader)
        **kwargs: Accepted for call compatibility; currently not forwarded
            to the loader.

    Returns:
        Appropriate DataLoader instance

    Examples:
        >>> loader = get_loader("data/activity.csv")
        >>> loader = get_loader("data/activity.parquet")
    """
    return FileDataLoader(file_path=Path(source), paths=paths)
+469
View File
@@ -0,0 +1,469 @@
"""
Database migration script for NHS High-Cost Drug Patient Pathway Analysis Tool.
Provides functions to initialize the SQLite database schema and CLI interface
for running migrations from the command line.
Usage:
# Initialize database (creates all tables)
python -m data_processing.migrate
# Drop existing tables and reinitialize
python -m data_processing.migrate --drop-existing
# Show current database status
python -m data_processing.migrate --status
# Migrate all reference data from CSV files
python -m data_processing.migrate --reference-data
# Migrate reference data with verification
python -m data_processing.migrate --reference-data --verify
"""
import argparse
import sys
from pathlib import Path
from typing import Optional
# Ensure src/ is on sys.path when run as `python -m data_processing.migrate`
_src_dir = str(Path(__file__).resolve().parent.parent)
if _src_dir not in sys.path:
sys.path.insert(0, _src_dir)
from core.logging_config import setup_logging, get_logger
from data_processing.database import DatabaseManager, DatabaseConfig
from core import PathConfig, default_paths
from data_processing.schema import (
create_all_tables,
drop_all_tables,
verify_all_tables_exist,
get_all_table_counts,
migrate_pathway_nodes_chart_type,
migrate_refresh_log_source_row_count,
)
from data_processing.reference_data import (
MigrationResult,
migrate_drug_names,
migrate_organizations,
migrate_directories,
migrate_drug_directory_map,
migrate_drug_indication_clusters,
verify_drug_names_migration,
verify_organizations_migration,
verify_directories_migration,
verify_drug_directory_map_migration,
verify_drug_indication_clusters_migration,
)
logger = get_logger(__name__)
def initialize_database(
    db_manager: Optional[DatabaseManager] = None,
    drop_existing: bool = False,
    confirm_drop: bool = True
) -> bool:
    """
    Bring the database schema up to date.

    Steps, in order: optionally drop existing tables, create all tables
    (IF NOT EXISTS, so re-running is safe), run the column-adding schema
    migrations, then verify that every expected table exists.

    Args:
        db_manager: DatabaseManager instance. Uses default if not provided.
        drop_existing: If True, drops all existing tables before creating.
        confirm_drop: If True and drop_existing=True, prompts for confirmation.
            Set to False for non-interactive use.

    Returns:
        True if initialization succeeded; False if the user declined the
        drop, table creation or a migration failed, or tables are missing.
    """
    manager = db_manager if db_manager is not None else DatabaseManager()

    logger.info(f"Initializing database at: {manager.db_path}")

    # Optional destructive reset, guarded by an interactive confirmation
    if drop_existing:
        if confirm_drop:
            print(f"\nWARNING: This will delete ALL data from the database:")
            print(f" {manager.db_path}\n")
            answer = input("Are you sure you want to continue? (yes/no): ")
            if answer.lower() not in ("yes", "y"):
                print("Operation cancelled.")
                return False

        if not manager.exists:
            logger.info("Database does not exist yet, nothing to drop")
        else:
            logger.warning("Dropping existing tables...")
            with manager.get_connection() as conn:
                drop_all_tables(conn)
                conn.commit()
            logger.info("Existing tables dropped")

    # Create all tables
    try:
        with manager.get_transaction() as conn:
            create_all_tables(conn)
    except Exception as exc:
        logger.error(f"Failed to create tables: {exc}")
        return False

    # Schema migrations for databases created before these columns existed:
    # chart_type on pathway_nodes, source_row_count on pathway_refresh_log.
    schema_migrations = (
        ("pathway_nodes", migrate_pathway_nodes_chart_type),
        ("pathway_refresh_log", migrate_refresh_log_source_row_count),
    )
    try:
        with manager.get_connection() as conn:
            for table, run_migration in schema_migrations:
                ok, detail = run_migration(conn)
                if not ok:
                    logger.error(f"{table} migration failed: {detail}")
                    return False
                logger.info(f"{table} migration: {detail}")
    except Exception as exc:
        logger.error(f"Migration failed: {exc}")
        return False

    # Verify all tables were created
    with manager.get_connection() as conn:
        missing = verify_all_tables_exist(conn)
    if missing:
        logger.error(f"Table creation failed. Missing tables: {missing}")
        return False

    logger.info("All tables created successfully")
    return True
def migrate_all_reference_data(
    db_manager: Optional[DatabaseManager] = None,
    paths: Optional[PathConfig] = None,
    verify: bool = False
) -> tuple[bool, list[MigrationResult]]:
    """
    Run all reference data migrations from CSV files to SQLite tables.

    Migrations are run in order:
    1. Drug names (drugnames.csv → ref_drug_names)
    2. Organizations (org_codes.csv → ref_organizations)
    3. Directories (directory_list.csv → ref_directories)
    4. Drug-directory mappings (drug_directory_list.csv → ref_drug_directory_map)
    5. Drug indication clusters (called without ``paths``)

    A failed migration does not stop the run; the remaining migrations are
    still attempted.

    Args:
        db_manager: DatabaseManager instance. Uses default if not provided.
        paths: PathConfig instance for locating CSV files. Uses default if not provided.
        verify: If True, runs verification after each successful migration.

    Returns:
        Tuple of (all_success: bool, results: list of MigrationResult).
        all_success is False if any migration or verification failed.
    """
    if db_manager is None:
        db_manager = DatabaseManager()
    if paths is None:
        paths = default_paths

    # (label, migrate function, verify function, takes a `paths` argument)
    # Drug indication clusters has a different signature (csv_path instead
    # of paths), so it is called with db_manager only.
    plan = [
        ("Drug names", migrate_drug_names, verify_drug_names_migration, True),
        ("Organizations", migrate_organizations, verify_organizations_migration, True),
        ("Directories", migrate_directories, verify_directories_migration, True),
        ("Drug-directory map", migrate_drug_directory_map, verify_drug_directory_map_migration, True),
        ("Drug indication clusters", migrate_drug_indication_clusters, verify_drug_indication_clusters_migration, False),
    ]

    results: list[MigrationResult] = []
    all_success = True

    logger.info(f"Starting reference data migrations ({len(plan)} tables)")

    for name, migrate_fn, verify_fn, uses_paths in plan:
        logger.info(f"Migrating: {name}...")

        call_kwargs = {"db_manager": db_manager}
        if uses_paths:
            call_kwargs["paths"] = paths

        result = migrate_fn(**call_kwargs)
        results.append(result)

        if not result.success:
            logger.error(f"Migration failed: {name} - {result.error_message}")
            all_success = False
            continue

        logger.info(f" {result}")

        if not verify:
            continue

        # Verification takes the same arguments as the migration itself.
        logger.info(f" Verifying {name}...")
        verified, verify_msg = verify_fn(**call_kwargs)
        if verified:
            logger.info(f" OK: {verify_msg}")
        else:
            logger.error(f" FAILED: Verification failed: {verify_msg}")
            all_success = False

    # Summary
    succeeded = len([r for r in results if r.success])
    logger.info(f"Reference data migrations complete: {succeeded}/{len(results)} succeeded")
    return all_success, results
def print_migration_summary(results: list[MigrationResult]) -> None:
    """Write a human-readable report of migration outcomes to stdout.

    For every result a status line (``[OK]`` / ``[FAILED]`` plus the table
    name) is printed, followed by either the row counters or the error
    message. A final line reports how many migrations succeeded.

    Args:
        results: Migration results, one per reference table.
    """
    print("\n=== Reference Data Migration Summary ===\n")

    successful = 0
    for result in results:
        if result.success:
            successful += 1
            status = "[OK]"
            detail = f" Read: {result.rows_read}, Inserted: {result.rows_inserted}, Skipped: {result.rows_skipped}"
        else:
            status = "[FAILED]"
            detail = f" Error: {result.error_message}"
        print(f"{status} {result.table_name}")
        print(detail)

    print(f"\nTotal: {successful}/{len(results)} migrations succeeded")
    print()
def create_progress_reporter(description: str = "Loading", width: int = 40):
    """
    Build a callback that draws a single-line progress bar on stdout.

    Args:
        description: Label to show before the progress bar.
        width: Width of the progress bar in characters.

    Returns:
        Callback function(current, total) that prints progress. The bar is
        only redrawn when the integer percentage differs from the one drawn
        last, and a newline is emitted once current reaches total.
    """
    previous_percent = -1

    def report_progress(current: int, total: int) -> None:
        """Redraw the bar for current/total if the percentage moved."""
        nonlocal previous_percent

        # A zero total counts as complete, which also avoids dividing by zero.
        percent = 100 if total == 0 else int(100 * current / total)
        if percent == previous_percent:
            # Same integer percentage as the last draw: skip the redundant write.
            return
        previous_percent = percent

        if total > 0:
            filled = int(width * current / total)
        else:
            filled = width
        bar = "=" * filled + "-" * (width - filled)

        # The leading carriage return rewinds to the start of the line so
        # the new bar overwrites the previous one.
        sys.stdout.write(f"\r{description}: [{bar}] {percent:3d}% ({current:,}/{total:,})")
        sys.stdout.flush()

        if current >= total:
            # Finished: move off the progress line.
            print()

    return report_progress
def get_database_status(db_manager: Optional[DatabaseManager] = None) -> dict:
    """
    Collect a snapshot of the database file and its tables.

    Args:
        db_manager: Manager for the database to inspect. A default
            DatabaseManager() is created when omitted.

    Returns:
        Dictionary with database status information:
        - exists: Whether the database file exists
        - path: Path to the database file
        - size_bytes: Size of database file (None if it does not exist)
        - tables: Dictionary of table names to row counts (left empty if
          the counts could not be read)
        - missing_tables: List of expected tables that don't exist
    """
    manager = DatabaseManager() if db_manager is None else db_manager

    status = {
        "exists": manager.exists,
        "path": str(manager.db_path),
        "size_bytes": None,
        "tables": {},
        "missing_tables": [],
    }

    if not manager.exists:
        # Nothing further to inspect without a database file.
        return status

    status["size_bytes"] = manager.db_path.stat().st_size
    with manager.get_connection() as conn:
        status["missing_tables"] = verify_all_tables_exist(conn)
        # Row counts are best-effort: a failure is logged and the status
        # is still returned with an empty "tables" mapping.
        try:
            status["tables"] = get_all_table_counts(conn)
        except Exception as e:
            logger.warning(f"Could not get table counts: {e}")

    return status
def print_database_status(db_manager: Optional[DatabaseManager] = None) -> None:
    """Render the result of get_database_status() as readable text on stdout.

    Args:
        db_manager: Forwarded unchanged to get_database_status().
    """
    status = get_database_status(db_manager)

    print("\n=== Database Status ===\n")
    print(f"Path: {status['path']}")
    print(f"Exists: {status['exists']}")

    if not status["exists"]:
        print("\nDatabase does not exist. Run migration to create it.")
        print()
        return

    # size_bytes may be None; treat that as zero for display purposes.
    size_kb = (status["size_bytes"] or 0) / 1024
    print(f"Size: {size_kb:.1f} KB")

    if not status["missing_tables"]:
        print("\nAll expected tables exist.")
    else:
        print(f"\nMissing tables: {', '.join(status['missing_tables'])}")

    if status["tables"]:
        print("\nTable row counts:")
        for table, count in sorted(status["tables"].items()):
            print(f" {table}: {count:,} rows")

    print()
def main():
    """CLI entry point for database migration.

    Parses the command line, then performs exactly one of:
    - ``--status``: print the database status;
    - ``--reference-data``: migrate reference data (initializing the
      schema first if the database does not exist);
    - default: initialize the schema, optionally dropping existing tables.

    Returns:
        Process exit code: 0 on success, 1 on any failure.
    """
    usage_examples = """
Examples:
python -m data_processing.migrate # Initialize database
python -m data_processing.migrate --status # Show database status
python -m data_processing.migrate --drop-existing # Reset database
python -m data_processing.migrate --reference-data # Migrate reference data
python -m data_processing.migrate --reference-data --verify # With verification
python -m data_processing.migrate --db-path ./data/test.db # Custom path
"""
    parser = argparse.ArgumentParser(
        description="Initialize NHS Pathways Analysis SQLite database schema",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )

    # (flags, keyword options) in the order they should appear in --help.
    option_specs = (
        (("--status",), {
            "action": "store_true",
            "help": "Show current database status and exit",
        }),
        (("--drop-existing",), {
            "action": "store_true",
            "help": "Drop all existing tables before creating (WARNING: deletes data)",
        }),
        (("--reference-data",), {
            "action": "store_true",
            "help": "Migrate all reference data from CSV files to SQLite tables",
        }),
        (("--verify",), {
            "action": "store_true",
            "help": "Verify migrated data matches CSV sources (use with --reference-data)",
        }),
        (("--db-path",), {
            "type": Path,
            "help": "Path to database file (default: ./data/pathways.db)",
        }),
        (("--yes", "-y"), {
            "action": "store_true",
            "help": "Skip confirmation prompts (for non-interactive use)",
        }),
        (("--verbose", "-v"), {
            "action": "store_true",
            "help": "Enable verbose logging",
        }),
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)

    args = parser.parse_args()

    setup_logging(level="DEBUG" if args.verbose else "INFO", simple_console=True)

    # A custom --db-path gets its own config; otherwise use the defaults.
    if not args.db_path:
        db_manager = DatabaseManager()
    else:
        db_manager = DatabaseManager(DatabaseConfig(db_path=args.db_path))

    if args.status:
        print_database_status(db_manager)
        return 0

    config_errors = db_manager.config.validate()
    if config_errors:
        for error in config_errors:
            logger.error(error)
        return 1

    if args.reference_data:
        # Reference tables must exist before rows can be migrated into them.
        if not db_manager.exists:
            print("Database does not exist. Initializing schema first...")
            if not initialize_database(db_manager=db_manager):
                print("\nDatabase initialization failed. Check logs for details.")
                return 1

        success, results = migrate_all_reference_data(
            db_manager=db_manager,
            paths=default_paths,
            verify=args.verify,
        )
        print_migration_summary(results)
        print_database_status(db_manager)

        if not success:
            print("Reference data migration completed with errors. Check logs for details.")
            return 1
        print("Reference data migration completed successfully.")
        return 0

    # Default behavior: create (or reset) the schema.
    success = initialize_database(
        db_manager=db_manager,
        drop_existing=args.drop_existing,
        confirm_drop=not args.yes,
    )
    if not success:
        print("\nDatabase initialization failed. Check logs for details.")
        return 1

    print("\nDatabase initialized successfully.")
    print_database_status(db_manager)
    return 0
# Script entry: propagate main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
+122
View File
@@ -0,0 +1,122 @@
"""Parsing utilities for pathway node data.
Shared functions for extracting structured data from pathway_nodes columns.
Used by analytics chart callbacks in dash_app/callbacks/.
"""
import re
def parse_average_spacing(spacing_html):
    """Pull per-drug dosing figures out of an average_spacing HTML string.

    Args:
        spacing_html: HTML such as
            '<br><b>DRUG</b><br>On average given 35.6 times with a 9.0 weekly
            interval (320.0 weeks total treatment length)' (all on one line).
            Several such entries may be concatenated in one string.

    Returns:
        One dict per matched entry, in order of appearance, with keys
        drug_name (str, stripped), dose_count, weekly_interval and
        total_weeks (floats). An empty list is returned for None/empty
        input or when nothing matches.
    """
    if not spacing_html:
        return []

    # Four capture groups: drug name, dose count, interval, total weeks.
    entry_pattern = (
        r"<b>([^<]+)</b><br>"
        r"On average given ([\d.]+) times "
        r"with a ([\d.]+) weekly interval "
        r"\(([\d.]+) weeks total treatment length\)"
    )

    return [
        {
            "drug_name": name.strip(),
            "dose_count": float(doses),
            "weekly_interval": float(interval),
            "total_weeks": float(weeks),
        }
        for name, doses, interval, weeks in re.findall(entry_pattern, spacing_html)
    ]
def parse_pathway_drugs(ids, level):
    """Return the ordered drug names encoded in a pathway ids string.

    Args:
        ids: String like 'ROOT - TRUST - DIR - DRUG_A - DRUG_B - DRUG_C'.
            Segments are separated by ' - '; indexes 0-2 are root, trust
            and directory, and everything from index 3 onwards is a drug.
        level: Node level. Nodes below level 3 carry no drugs.

    Returns:
        List of drug names in treatment order. Empty list when ids is
        empty/None, when level < 3, or when ids has three or fewer segments.
    """
    if not ids or level < 3:
        return []

    # Slicing past the end yields [], which covers ids with <= 3 segments.
    return ids.split(" - ")[3:]
def _get_patients(node):
    """Return a node's patient count, read from 'value' or else 'patients'.

    The first of the two keys holding a truthy value wins; 0 is returned
    when neither key is present or both hold falsy values.
    """
    for key in ("value", "patients"):
        count = node.get(key)
        if count:
            return count
    return 0
def calculate_retention_rate(nodes):
    """Calculate pathway retention rates from node data.

    For each N-drug pathway, calculate what % of patients do NOT escalate
    to an N+1 drug pathway. This identifies effective treatment sequences.

    A node's children are the nodes exactly one level deeper whose ids
    start with the node's ids followed by ' - '.

    Args:
        nodes: List of dicts with 'ids', 'level', and 'value' or 'patients' keys.
            Only nodes at level 4 or deeper with a non-zero patient count
            produce an entry in the result. A missing or None 'ids' is
            treated as an empty string.

    Returns:
        Dict mapping pathway ids to retention info:
        {ids: {"retained_patients": int, "total_patients": int,
               "retention_rate": float, "drug_sequence": list}}
        retention_rate is a percentage rounded to one decimal place.
    """
    if not nodes:
        return {}

    # Bucket nodes by level once so each pathway only scans candidates one
    # level below it, rather than rescanning the entire node list.
    nodes_by_level = {}
    for candidate in nodes:
        nodes_by_level.setdefault(candidate.get("level", 0), []).append(candidate)

    results = {}
    for node in nodes:
        level = node.get("level", 0)
        if level < 4:
            continue

        # "or ''" also covers an explicit None, which would otherwise break
        # the string concatenation below.
        node_ids = node.get("ids") or ""
        total_patients = _get_patients(node)
        if not total_patients:
            continue

        # Find child pathways (nodes whose ids start with this node's ids + " - ")
        child_prefix = node_ids + " - "
        child_patients = sum(
            _get_patients(child)
            for child in nodes_by_level.get(level + 1, ())
            if (child.get("ids") or "").startswith(child_prefix)
        )

        retained = total_patients - child_patients
        retention_rate = (retained / total_patients * 100) if total_patients > 0 else 0.0

        results[node_ids] = {
            "retained_patients": retained,
            "total_patients": total_patients,
            "retention_rate": round(retention_rate, 1),
            "drug_sequence": parse_pathway_drugs(node_ids, level),
        }

    return results
+642
View File
@@ -0,0 +1,642 @@
"""
Pathway data processing pipeline.
This module provides functions to:
1. Fetch and transform raw intervention data from Snowflake
2. Process data for each of the 6 date filter combinations
3. Extract denormalized fields from hierarchical path strings
4. Convert processed data to records for SQLite storage
The pipeline integrates with:
- analysis/pathway_analyzer.py: generate_icicle_chart() for pathway processing
- data_processing/snowflake_connector.py: fetch_activity_data() for data retrieval
- data_processing/transforms.py: patient_id(), drug_names(), department_identification()
"""
from dataclasses import dataclass
from datetime import date, timedelta
from typing import Optional, Literal
import json
import numpy as np
import pandas as pd
from core import PathConfig, default_paths
from core.logging_config import get_logger
from analysis.pathway_analyzer import generate_icicle_chart, generate_icicle_chart_indication
from data_processing.transforms import patient_id, drug_names, department_identification
logger = get_logger(__name__)
# Type alias for chart types: "directory" charts group pathways by directory,
# "indication" charts group them by indication (see the process_* functions below).
ChartType = Literal["directory", "indication"]
@dataclass
class DateFilterConfig:
    """Configuration for a date filter combination.

    Pairs an "initiated within the last N years" window with a "last seen
    within the last N months" window. compute_date_ranges() turns an
    instance into concrete ISO date strings.
    """
    # Identifier of the combination, e.g. 'all_6mo', '1yr_12mo'
    id: str  # e.g., 'all_6mo', '1yr_12mo'
    # Size of the initiated window in years; None means no lower bound ("All")
    initiated_years: Optional[int]  # None for 'All', 1, or 2
    # Size of the last-seen window in months
    last_seen_months: int  # 6 or 12
# Pre-defined date filter configurations matching pathway_date_filters table.
# Each row below is (id, initiated_years, last_seen_months); list order is
# the order in which the filters are processed.
DATE_FILTER_CONFIGS = [
    DateFilterConfig(id=filter_id, initiated_years=years, last_seen_months=months)
    for filter_id, years, months in (
        ("all_6mo", None, 6),
        ("all_12mo", None, 12),
        ("1yr_6mo", 1, 6),
        ("1yr_12mo", 1, 12),
        ("2yr_6mo", 2, 6),
        ("2yr_12mo", 2, 12),
    )
]
def compute_date_ranges(
    config: DateFilterConfig,
    max_date: Optional[date] = None,
) -> tuple[str, str, str]:
    """
    Compute actual date strings from a date filter configuration.

    Args:
        config: DateFilterConfig with initiated_years and last_seen_months
        max_date: Reference date (defaults to today)

    Returns:
        Tuple of (start_date, end_date, last_seen_date) as ISO format strings
        - start_date: Start of initiated filter period. 2000-01-01 when
          initiated_years is None, otherwise max_date shifted back by that
          many calendar years (29 February falls back to 28 February when
          the target year is not a leap year).
        - end_date: End of initiated filter period (always max_date)
        - last_seen_date: Date threshold for last_seen filter, computed with
          months approximated as 30 days each.

    Raises:
        ValueError: If the shifted start year is outside the range that
            datetime.date supports.
    """
    if max_date is None:
        max_date = date.today()

    # Calculate end_date (always max_date)
    end_date = max_date

    # Calculate start_date based on initiated_years
    if config.initiated_years is None:
        # "All years" - use a very old date
        start_date = date(2000, 1, 1)
    else:
        # Last N years from max_date
        target_year = max_date.year - config.initiated_years
        try:
            start_date = max_date.replace(year=target_year)
        except ValueError:
            # date.replace() rejects 29 Feb in a non-leap target year.
            # Any other ValueError (e.g. year out of range) is a real error.
            if (max_date.month, max_date.day) != (2, 29):
                raise
            start_date = max_date.replace(year=target_year, day=28)

    # Calculate last_seen_date based on last_seen_months
    # Patients must have been seen within the last N months
    last_seen_date = max_date - timedelta(days=config.last_seen_months * 30)

    return (
        start_date.isoformat(),
        end_date.isoformat(),
        last_seen_date.isoformat(),
    )
def fetch_and_transform_data(
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    provider_codes: Optional[list[str]] = None,
    paths: Optional[PathConfig] = None,
) -> pd.DataFrame:
    """
    Pull activity data from Snowflake and run the standard transform chain.

    Steps, in order:
    1. Fetch raw activity records from Snowflake (no row limit)
    2. patient_id(): generate UPID
    3. drug_names(): standardize drug names, then drop rows whose
       'Drug Name' is NaN
    4. department_identification(): assign directories
    5. Coerce 'Intervention Date' to datetime

    Args:
        start_date: Optional start date filter for Snowflake query
        end_date: Optional end date filter for Snowflake query
        provider_codes: Optional list of provider codes to filter
        paths: PathConfig for file paths (uses default if None)

    Returns:
        Transformed DataFrame, or an empty DataFrame when Snowflake
        returned no records.

    Raises:
        ImportError: If snowflake-connector-python is not installed
    """
    resolved_paths = default_paths if paths is None else paths

    # Imported lazily: avoids circular imports and keeps the Snowflake
    # dependency optional for callers that never fetch.
    from data_processing.snowflake_connector import get_connector, is_snowflake_available

    if not is_snowflake_available():
        raise ImportError(
            "snowflake-connector-python is not installed. "
            "Install it with: pip install snowflake-connector-python"
        )

    logger.info("Fetching activity data from Snowflake...")
    raw_data = get_connector().fetch_activity_data(
        start_date=start_date,
        end_date=end_date,
        provider_codes=provider_codes,
        max_rows=0,  # No limit
    )
    if not raw_data:
        logger.warning("No data returned from Snowflake")
        return pd.DataFrame()
    logger.info(f"Fetched {len(raw_data)} records from Snowflake")

    df = pd.DataFrame(raw_data)
    logger.info("Applying data transformations...")

    # Step 1: patient identifiers.
    df = patient_id(df)
    logger.info(f"Generated UPID for {df['UPID'].nunique()} unique patients")

    # Step 2: drug names; rows left without a 'Drug Name' are discarded.
    df = drug_names(df, resolved_paths)
    rows_before_drop = len(df)
    df = df.dropna(subset=['Drug Name'])
    dropped = rows_before_drop - len(df)
    if dropped:
        logger.info(f"Removed {dropped} rows with unmapped drug names")

    # Step 3: directory assignment.
    df = department_identification(df, resolved_paths)
    logger.info(f"Assigned directories to {len(df)} records")

    # Downstream code expects real datetimes in this column.
    df['Intervention Date'] = pd.to_datetime(df['Intervention Date'])

    logger.info(f"Data transformation complete. Final record count: {len(df)}")
    return df
def process_pathway_for_date_filter(
    df: pd.DataFrame,
    config: DateFilterConfig,
    trust_filter: list[str],
    drug_filter: list[str],
    directory_filter: list[str],
    minimum_patients: int = 5,
    max_date: Optional[date] = None,
    paths: Optional[PathConfig] = None,
) -> Optional[pd.DataFrame]:
    """
    Build the directory-grouped pathway hierarchy for one date filter.

    Resolves the filter's date window with compute_date_ranges() and hands
    everything to generate_icicle_chart() from pathway_analyzer.py.

    Args:
        df: Transformed DataFrame from fetch_and_transform_data()
        config: DateFilterConfig specifying the date filter combination
        trust_filter: List of trust names to include
        drug_filter: List of drug names to include
        directory_filter: List of directories to include
        minimum_patients: Minimum patients to include a pathway
        max_date: Reference date for computing date ranges
        paths: PathConfig for file paths (uses default if None)

    Returns:
        DataFrame with pathway hierarchy (ice_df), or None when the
        analyzer returned None or an empty frame
    """
    resolved_paths = default_paths if paths is None else paths

    start_date, end_date, last_seen_date = compute_date_ranges(config, max_date)

    logger.info(f"Processing pathway for {config.id}")
    logger.info(f" Date range: {start_date} to {end_date}")
    logger.info(f" Last seen after: {last_seen_date}")

    analyzer_kwargs = {
        "df": df,
        "start_date": start_date,
        "end_date": end_date,
        "last_seen_date": last_seen_date,
        "trust_filter": trust_filter,
        "drug_filter": drug_filter,
        "directory_filter": directory_filter,
        "minimum_num_patients": minimum_patients,
        "title": "",
        "paths": resolved_paths,
    }
    # The analyzer also returns a title, which is not needed here.
    ice_df, _unused_title = generate_icicle_chart(**analyzer_kwargs)

    if ice_df is not None and len(ice_df) > 0:
        logger.info(f"Generated {len(ice_df)} pathway nodes for {config.id}")
        return ice_df

    logger.warning(f"No pathway data for filter {config.id}")
    return None
def process_indication_pathway_for_date_filter(
    df: pd.DataFrame,
    indication_df: pd.DataFrame,
    config: DateFilterConfig,
    trust_filter: list[str],
    drug_filter: list[str],
    directory_filter: list[str],
    minimum_patients: int = 5,
    max_date: Optional[date] = None,
    paths: Optional[PathConfig] = None,
) -> Optional[pd.DataFrame]:
    """
    Build the indication-grouped pathway hierarchy for one date filter.

    Counterpart of process_pathway_for_date_filter() that delegates to
    generate_icicle_chart_indication(), grouping by indication
    (Search_Term from GP diagnosis) instead of directory.

    Hierarchy: Trust → Indication_Group → Drug → Pathway

    Args:
        df: Transformed DataFrame from fetch_and_transform_data()
        indication_df: DataFrame with UPID → Indication_Group mapping
            Must have columns: UPID, Indication_Group
            Indication_Group is either Search_Term or "Directory (no GP dx)"
        config: DateFilterConfig specifying the date filter combination
        trust_filter: List of trust names to include
        drug_filter: List of drug names to include
        directory_filter: List of directories to include
        minimum_patients: Minimum patients to include a pathway
        max_date: Reference date for computing date ranges
        paths: PathConfig for file paths (uses default if None)

    Returns:
        DataFrame with pathway hierarchy (ice_df), or None when the
        analyzer returned None or an empty frame
    """
    resolved_paths = default_paths if paths is None else paths

    start_date, end_date, last_seen_date = compute_date_ranges(config, max_date)

    logger.info(f"Processing indication pathway for {config.id}")
    logger.info(f" Date range: {start_date} to {end_date}")
    logger.info(f" Last seen after: {last_seen_date}")

    analyzer_kwargs = {
        "df": df,
        "indication_df": indication_df,
        "start_date": start_date,
        "end_date": end_date,
        "last_seen_date": last_seen_date,
        "trust_filter": trust_filter,
        "drug_filter": drug_filter,
        "directory_filter": directory_filter,
        "minimum_num_patients": minimum_patients,
        "title": "",
        "paths": resolved_paths,
    }
    # The analyzer also returns a title, which is not needed here.
    ice_df, _unused_title = generate_icicle_chart_indication(**analyzer_kwargs)

    if ice_df is not None and len(ice_df) > 0:
        logger.info(f"Generated {len(ice_df)} indication pathway nodes for {config.id}")
        return ice_df

    logger.warning(f"No indication pathway data for filter {config.id}")
    return None
def extract_denormalized_fields(ice_df: pd.DataFrame) -> pd.DataFrame:
    """
    Derive flat filter columns from the hierarchical ids column.

    ids values are ' - '-separated paths such as:
    - "N&WICS"                                                  (root)
    - "N&WICS - NNUH"                                           (trust)
    - "N&WICS - NNUH - OPHTHALMOLOGY"                           (directory)
    - "N&WICS - NNUH - OPHTHALMOLOGY - RANIBIZUMAB"             (first drug)
    - "N&WICS - NNUH - OPHTHALMOLOGY - RANIBIZUMAB - AFLIBERCEPT" (pathway)

    Columns added (empty string where the path is too short):
    - trust_name: segment 1
    - directory: segment 2
    - drug_sequence: segments 3+ joined with '|'

    Args:
        ice_df: DataFrame from generate_icicle_chart()

    Returns:
        Copy of ice_df with added columns: trust_name, directory, drug_sequence
    """
    df = ice_df.copy()

    def split_path(ids_str: str) -> tuple[str, str, str]:
        """Return (trust, directory, drug sequence) for one ids value."""
        if not ids_str or pd.isna(ids_str):
            return ("", "", "")

        _root, *below_root = ids_str.split(" - ")
        if not below_root:
            # Root-only path: nothing to extract.
            return ("", "", "")

        trust_name = below_root[0]
        directory = below_root[1] if len(below_root) > 1 else ""
        # Joining an empty slice yields "", matching the no-drug case.
        drug_sequence = "|".join(below_root[2:])
        return (trust_name, directory, drug_sequence)

    extracted = df['ids'].apply(split_path)
    for position, column in enumerate(('trust_name', 'directory', 'drug_sequence')):
        # position is bound as a default so each lambda keeps its own index.
        df[column] = extracted.apply(lambda parts, position=position: parts[position])

    logger.info(f"Extracted denormalized fields for {len(df)} nodes")
    logger.info(f" Unique trusts: {df['trust_name'].nunique()}")
    logger.info(f" Unique directories: {df['directory'].nunique()}")
    return df
def extract_indication_fields(ice_df: pd.DataFrame) -> pd.DataFrame:
    """
    Derive flat filter columns from the ids column of an indication chart.

    Works like extract_denormalized_fields(), except that segment 2 of the
    path is a Search_Term (or a fallback directorate) rather than a directory.

    ids values are ' - '-separated paths such as:
    - "N&WICS"                                            (root)
    - "N&WICS - NNUH"                                     (trust)
    - "N&WICS - NNUH - rheumatoid arthritis"              (search_term, matched patient)
    - "N&WICS - NNUH - RHEUMATOLOGY (no GP dx)"           (fallback, unmatched patient)
    - "N&WICS - NNUH - rheumatoid arthritis - ADALIMUMAB" (first drug)
    - "N&WICS - NNUH - rheumatoid arthritis - ADALIMUMAB - ETANERCEPT" (pathway)

    Columns added (empty string where the path is too short):
    - trust_name: segment 1
    - directory: segment 2, i.e. the search_term; the column keeps the name
      'directory' for schema compatibility with the pathway_nodes table
    - drug_sequence: segments 3+ joined with '|'

    Args:
        ice_df: DataFrame from generate_icicle_chart() with indication grouping

    Returns:
        Copy of ice_df with added columns: trust_name, directory (=search_term),
        drug_sequence
    """
    df = ice_df.copy()

    def split_path(ids_str: str) -> tuple[str, str, str]:
        """Return (trust, search_term, drug sequence) for one ids value."""
        if not ids_str or pd.isna(ids_str):
            return ("", "", "")

        _root, *below_root = ids_str.split(" - ")
        if not below_root:
            # Root-only path: nothing to extract.
            return ("", "", "")

        trust_name = below_root[0]
        search_term = below_root[1] if len(below_root) > 1 else ""
        # Joining an empty slice yields "", matching the no-drug case.
        drug_sequence = "|".join(below_root[2:])
        return (trust_name, search_term, drug_sequence)

    extracted = df['ids'].apply(split_path)
    # 'directory' deliberately receives the search_term (index 1).
    for position, column in enumerate(('trust_name', 'directory', 'drug_sequence')):
        # position is bound as a default so each lambda keeps its own index.
        df[column] = extracted.apply(lambda parts, position=position: parts[position])

    logger.info(f"Extracted indication fields for {len(df)} nodes")
    logger.info(f" Unique trusts: {df['trust_name'].nunique()}")
    logger.info(f" Unique search_terms: {df['directory'].nunique()}")
    return df
def convert_to_records(
    ice_df: pd.DataFrame,
    date_filter_id: str,
    refresh_id: Optional[str] = None,
    chart_type: ChartType = "directory",
) -> list[dict]:
    """
    Convert ice_df to a list of dictionaries for SQLite insertion.

    Maps ice_df columns to pathway_nodes table schema:
    - parents, ids, labels: Direct mapping (missing values become '')
    - level: From ice_df['level'] (missing values become 0)
    - value, cost, costpp, colour: Direct mapping
    - cost_pp_pa: From ice_df['cost_pp_pa'], stored as a string
    - first_seen, last_seen, first_seen_parent, last_seen_parent: Date columns
    - average_spacing: From ice_df['average_spacing']
    - average_administered: JSON serialization of list
    - avg_days: From ice_df['avg_days'], stored as a number of days
    - trust_name, directory, drug_sequence: Denormalized fields
    - date_filter_id: The filter combination ID
    - chart_type: "directory" or "indication"
    - data_refresh_id: Optional refresh tracking ID

    Args:
        ice_df: DataFrame from generate_icicle_chart() with denormalized fields
        date_filter_id: The date filter combination ID (e.g., 'all_6mo')
        refresh_id: Optional refresh tracking ID
        chart_type: Chart type - "directory" (default) or "indication"

    Returns:
        List of dictionaries ready for SQLite insertion, one per row of ice_df
    """
    records = []

    for _, row in ice_df.iterrows():
        # Handle date formatting; each value stays None when the source is missing/NaN
        first_seen = None
        last_seen = None
        first_seen_parent = None
        last_seen_parent = None

        # Prefer isoformat() for date-like values, fall back to str() otherwise
        if pd.notna(row.get('First seen')):
            if hasattr(row['First seen'], 'isoformat'):
                first_seen = row['First seen'].isoformat()
            else:
                first_seen = str(row['First seen'])

        if pd.notna(row.get('Last seen')):
            if hasattr(row['Last seen'], 'isoformat'):
                last_seen = row['Last seen'].isoformat()
            else:
                last_seen = str(row['Last seen'])

        # NOTE(review): the parent dates are always rendered with str(), not
        # isoformat(), so date-like values may be formatted differently from
        # first_seen/last_seen - confirm this is intended.
        if pd.notna(row.get('First seen (Parent)')):
            first_seen_parent = str(row['First seen (Parent)'])

        if pd.notna(row.get('Last seen (Parent)')):
            last_seen_parent = str(row['Last seen (Parent)'])

        # Handle average_administered (could be list, ndarray, or None)
        average_administered = None
        val = row.get('average_administered')
        if val is not None:
            # Check for scalar None-like values
            try:
                if pd.isna(val):
                    average_administered = None
                elif isinstance(val, (list, np.ndarray)):
                    average_administered = json.dumps(list(val) if hasattr(val, 'tolist') else val)
                else:
                    average_administered = str(val)
            except (ValueError, TypeError):
                # For a list/array with more than one element pd.isna returns an
                # array, and truth-testing that array in the `if` raises ValueError.
                # In that case, val is an array/list, so convert to JSON
                if hasattr(val, 'tolist'):
                    average_administered = json.dumps(val.tolist())
                elif isinstance(val, list):
                    average_administered = json.dumps(val)
                else:
                    average_administered = str(val)

        # Each field is coerced to a plain Python type; missing/NaN values fall
        # back to the default on the right-hand side of the conditional.
        record = {
            'date_filter_id': date_filter_id,
            'chart_type': chart_type,
            'parents': str(row.get('parents', '')) if pd.notna(row.get('parents')) else '',
            'ids': str(row.get('ids', '')) if pd.notna(row.get('ids')) else '',
            'labels': str(row.get('labels', '')) if pd.notna(row.get('labels')) else '',
            'level': int(row.get('level', 0)) if pd.notna(row.get('level')) else 0,
            'value': int(row.get('value', 0)) if pd.notna(row.get('value')) else 0,
            'cost': float(row.get('cost', 0)) if pd.notna(row.get('cost')) else 0.0,
            'costpp': float(row.get('costpp')) if pd.notna(row.get('costpp')) else None,
            'cost_pp_pa': str(row.get('cost_pp_pa', '')) if pd.notna(row.get('cost_pp_pa')) else None,
            'colour': float(row.get('colour', 0)) if pd.notna(row.get('colour')) else 0.0,
            'first_seen': first_seen,
            'last_seen': last_seen,
            'first_seen_parent': first_seen_parent,
            'last_seen_parent': last_seen_parent,
            'average_spacing': str(row.get('average_spacing', '')) if pd.notna(row.get('average_spacing')) else None,
            'average_administered': average_administered,
            # Timedelta-like values (anything with total_seconds) are converted to
            # fractional days (86400 seconds per day); other non-missing values are
            # taken as already numeric.
            'avg_days': float(row['avg_days'].total_seconds() / 86400) if pd.notna(row.get('avg_days')) and hasattr(row.get('avg_days'), 'total_seconds') else (float(row.get('avg_days')) if pd.notna(row.get('avg_days')) else None),
            'trust_name': row.get('trust_name', '') if pd.notna(row.get('trust_name')) else None,
            'directory': row.get('directory', '') if pd.notna(row.get('directory')) else None,
            'drug_sequence': row.get('drug_sequence', '') if pd.notna(row.get('drug_sequence')) else None,
            'data_refresh_id': refresh_id,
        }
        records.append(record)

    logger.info(f"Converted {len(records)} pathway nodes to records for {date_filter_id} ({chart_type})")
    return records
def process_all_date_filters(
    df: pd.DataFrame,
    trust_filter: list[str],
    drug_filter: list[str],
    directory_filter: list[str],
    minimum_patients: int = 5,
    max_date: Optional[date] = None,
    refresh_id: Optional[str] = None,
    paths: Optional[PathConfig] = None,
) -> dict[str, list[dict]]:
    """
    Run the directory-chart pipeline for every entry in DATE_FILTER_CONFIGS.

    For each filter the pathway hierarchy is generated, denormalized and
    converted into insertion-ready records. Filters that yield no data
    map to an empty list.

    Args:
        df: Transformed DataFrame from fetch_and_transform_data()
        trust_filter: List of trust names to include
        drug_filter: List of drug names to include
        directory_filter: List of directories to include
        minimum_patients: Minimum patients to include a pathway
        max_date: Reference date for computing date ranges
        refresh_id: Optional refresh tracking ID
        paths: PathConfig for file paths (uses default if None)

    Returns:
        Dictionary mapping date_filter_id to list of record dicts
        e.g., {"all_6mo": [...], "all_12mo": [...], ...}
    """
    resolved_paths = default_paths if paths is None else paths

    results = {}
    for config in DATE_FILTER_CONFIGS:
        logger.info(f"Processing date filter: {config.id}")

        ice_df = process_pathway_for_date_filter(
            df=df,
            config=config,
            trust_filter=trust_filter,
            drug_filter=drug_filter,
            directory_filter=directory_filter,
            minimum_patients=minimum_patients,
            max_date=max_date,
            paths=resolved_paths,
        )

        if ice_df is None:
            # Keep the key present so callers see every filter id.
            logger.warning(f"Skipping {config.id} - no data")
            results[config.id] = []
            continue

        # Denormalize, then flatten into records (chart type defaults to directory).
        records = convert_to_records(
            extract_denormalized_fields(ice_df),
            config.id,
            refresh_id,
        )
        results[config.id] = records
        logger.info(f"Completed {config.id}: {len(records)} nodes")

    total_records = sum(map(len, results.values()))
    logger.info(f"Total pathway nodes across all filters: {total_records}")
    return results
# Export public API: the names made available by
# `from data_processing.pathway_pipeline import *`.
__all__ = [
    # Types
    "ChartType",
    # Data classes
    "DateFilterConfig",
    "DATE_FILTER_CONFIGS",
    # Core functions
    "compute_date_ranges",
    "fetch_and_transform_data",
    # Directory chart processing
    "process_pathway_for_date_filter",
    "extract_denormalized_fields",
    # Indication chart processing
    "process_indication_pathway_for_date_filter",
    "extract_indication_fields",
    # Common utilities
    "convert_to_records",
    "process_all_date_filters",
]
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+709
View File
@@ -0,0 +1,709 @@
"""
SQLite schema definitions for NHS High-Cost Drug Patient Pathway Analysis Tool.
Contains SQL strings for creating reference tables, fact tables, and indexes.
Schema design supports:
- Reference data from CSV files (drug names, organizations, directories)
- Drug-directory mappings with single-valid-directory flag
- Patient intervention facts with proper indexing
- Cached aggregations for performance
- File tracking for incremental updates
"""
from typing import Optional
import sqlite3
from core.logging_config import get_logger
logger = get_logger(__name__)
# =============================================================================
# Reference Table Schemas
# =============================================================================
# DDL script for ref_drug_names: one row per raw drug name with the
# standardized name it maps to, plus lookup indexes on both columns.
# Every statement is IF NOT EXISTS, so the script can be re-run safely.
# NOTE(review): raw_name is declared UNIQUE, which normally gives SQLite an
# implicit index already; idx_ref_drug_names_raw looks redundant — confirm
# before removing.
REF_DRUG_NAMES_SCHEMA = """
-- Mapping from raw drug names (as they appear in source data) to standardized names
-- Source: data/drugnames.csv
CREATE TABLE IF NOT EXISTS ref_drug_names (
id INTEGER PRIMARY KEY AUTOINCREMENT,
raw_name TEXT NOT NULL UNIQUE,
standard_name TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- Index for fast lookups during data transformation
CREATE INDEX IF NOT EXISTS idx_ref_drug_names_raw ON ref_drug_names(raw_name);
CREATE INDEX IF NOT EXISTS idx_ref_drug_names_standard ON ref_drug_names(standard_name);
"""
# DDL script for ref_organizations: one row per organization code with its
# display name. Every statement is IF NOT EXISTS, so re-running is safe.
# NOTE(review): org_code is UNIQUE, so the explicit index on it may be
# redundant with SQLite's implicit unique index — confirm.
REF_ORGANIZATIONS_SCHEMA = """
-- NHS organization codes and names
-- Source: data/org_codes.csv
CREATE TABLE IF NOT EXISTS ref_organizations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
org_code TEXT NOT NULL UNIQUE,
org_name TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- Index for fast lookups by organization code
CREATE INDEX IF NOT EXISTS idx_ref_organizations_code ON ref_organizations(org_code);
"""
# DDL script for ref_directories: the list of directory (specialty) names,
# one unique name per row. Every statement is IF NOT EXISTS.
# NOTE(review): directory_name is UNIQUE, so the explicit index on it may be
# redundant with SQLite's implicit unique index — confirm.
REF_DIRECTORIES_SCHEMA = """
-- Medical directories/specialties
-- Source: data/directory_list.csv
CREATE TABLE IF NOT EXISTS ref_directories (
id INTEGER PRIMARY KEY AUTOINCREMENT,
directory_name TEXT NOT NULL UNIQUE,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- Index for fast lookups by directory name
CREATE INDEX IF NOT EXISTS idx_ref_directories_name ON ref_directories(directory_name);
"""
# DDL script for ref_drug_directory_map: one row per (drug_name,
# directory_name) pair, unique on that pair, with an is_single_valid flag
# (default 0) and indexes on each of the three data columns.
# Every statement is IF NOT EXISTS, so the script can be re-run safely.
REF_DRUG_DIRECTORY_MAP_SCHEMA = """
-- Mapping from drug names to valid directories
-- Source: data/drug_directory_list.csv
-- A drug may map to multiple directories (one row per drug-directory pair)
-- The is_single_valid flag indicates drugs with exactly ONE valid directory,
-- which enables automatic directory assignment in department_identification()
CREATE TABLE IF NOT EXISTS ref_drug_directory_map (
id INTEGER PRIMARY KEY AUTOINCREMENT,
drug_name TEXT NOT NULL,
directory_name TEXT NOT NULL,
is_single_valid BOOLEAN NOT NULL DEFAULT 0,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE(drug_name, directory_name)
);
-- Index for looking up directories by drug name (most common access pattern)
CREATE INDEX IF NOT EXISTS idx_ref_drug_directory_map_drug ON ref_drug_directory_map(drug_name);
-- Index for reverse lookup (find drugs by directory)
CREATE INDEX IF NOT EXISTS idx_ref_drug_directory_map_directory ON ref_drug_directory_map(directory_name);
-- Index for quick filtering of single-valid drugs
CREATE INDEX IF NOT EXISTS idx_ref_drug_directory_map_single ON ref_drug_directory_map(is_single_valid);
"""
# DDL script for ref_drug_indication_clusters: one row per
# (drug_name, indication, cluster_id) combination, unique on that triple,
# with optional cluster_description / nice_ta_reference text and one index
# per lookup column. Every statement is IF NOT EXISTS.
REF_DRUG_INDICATION_CLUSTERS_SCHEMA = """
-- Mapping from drugs to SNOMED clusters for indication validation
-- Source: data/drug_indication_clusters.csv
-- Used to validate that patients have appropriate GP diagnoses for their prescribed drugs
-- A drug may map to multiple clusters (one row per drug-indication-cluster combination)
CREATE TABLE IF NOT EXISTS ref_drug_indication_clusters (
id INTEGER PRIMARY KEY AUTOINCREMENT,
drug_name TEXT NOT NULL,
indication TEXT NOT NULL,
cluster_id TEXT NOT NULL,
cluster_description TEXT,
nice_ta_reference TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE(drug_name, indication, cluster_id)
);
-- Index for looking up clusters by drug name (most common access pattern)
CREATE INDEX IF NOT EXISTS idx_ref_drug_indication_clusters_drug ON ref_drug_indication_clusters(drug_name);
-- Index for looking up drugs by cluster (for finding all drugs treating a condition)
CREATE INDEX IF NOT EXISTS idx_ref_drug_indication_clusters_cluster ON ref_drug_indication_clusters(cluster_id);
-- Index for looking up by indication text
CREATE INDEX IF NOT EXISTS idx_ref_drug_indication_clusters_indication ON ref_drug_indication_clusters(indication);
"""
# =============================================================================
# Pathway Data Architecture Schemas
# =============================================================================
# DDL + seed script for pathway_date_filters: creates the table and then
# writes the six fixed date filter rows keyed by id ('all_6mo', ...).
# The seed uses INSERT OR REPLACE, so every execution of this script rewrites
# those six rows rather than failing on the primary key.
# NOTE(review): a replaced row presumably gets a fresh created_at default —
# confirm if the original creation time matters.
PATHWAY_DATE_FILTERS_SCHEMA = """
-- Stores the 6 pre-computed date filter combinations
-- Each combination represents a different initiated/last_seen date range
-- Used to efficiently query pre-computed pathway data
CREATE TABLE IF NOT EXISTS pathway_date_filters (
id TEXT PRIMARY KEY, -- e.g., 'all_6mo', '1yr_12mo'
initiated_label TEXT NOT NULL, -- e.g., 'All years', 'Last 1 year', 'Last 2 years'
last_seen_label TEXT NOT NULL, -- e.g., 'Last 6 months', 'Last 12 months'
initiated_years INTEGER, -- NULL for 'All', 1, or 2
last_seen_months INTEGER NOT NULL, -- 6 or 12
is_default INTEGER DEFAULT 0, -- 1 for 'all_6mo' (default selection)
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- Pre-populate the 6 combinations
INSERT OR REPLACE INTO pathway_date_filters (id, initiated_label, last_seen_label, initiated_years, last_seen_months, is_default) VALUES
('all_6mo', 'All years', 'Last 6 months', NULL, 6, 1),
('all_12mo', 'All years', 'Last 12 months', NULL, 12, 0),
('1yr_6mo', 'Last 1 year', 'Last 6 months', 1, 6, 0),
('1yr_12mo', 'Last 1 year', 'Last 12 months', 1, 12, 0),
('2yr_6mo', 'Last 2 years', 'Last 6 months', 2, 6, 0),
('2yr_12mo', 'Last 2 years', 'Last 12 months', 2, 12, 0);
"""
# DDL script for pathway_nodes: the pre-computed hierarchy rows, unique per
# (date_filter_id, chart_type, ids) and referencing pathway_date_filters(id).
# Also creates the filtering indexes; every statement is IF NOT EXISTS.
# Databases created before chart_type existed are upgraded separately by
# migrate_pathway_nodes_chart_type().
# NOTE(review): idx_pathway_nodes_drug_seq is described below as serving
# LIKE '%DRUG%' filters; a leading-wildcard LIKE generally cannot use a
# B-tree index — confirm this index is actually useful.
PATHWAY_NODES_SCHEMA = """
-- Main pathway nodes table (one set per date filter + chart type combination)
-- Stores pre-computed pathway hierarchy with all visualization data
-- Designed for fast filtering by date_filter_id + chart_type + trust/directory/drug
CREATE TABLE IF NOT EXISTS pathway_nodes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- Date filter combination this belongs to
date_filter_id TEXT NOT NULL,
-- Chart type: "directory" (Trust→Directory→Drug) or "indication" (Trust→SearchTerm→Drug)
chart_type TEXT NOT NULL DEFAULT 'directory',
-- Hierarchy structure (for icicle chart)
parents TEXT NOT NULL, -- Parent node identifier
ids TEXT NOT NULL, -- Unique node identifier (hierarchical path)
labels TEXT NOT NULL, -- Display label
level INTEGER NOT NULL, -- Hierarchy depth (0=root, 1=trust, 2=directory/search_term, 3+=drugs)
-- Patient counts (accurate for this date filter combination)
value INTEGER NOT NULL DEFAULT 0, -- Patient count
-- Cost metrics
cost REAL NOT NULL DEFAULT 0.0, -- Total cost
costpp REAL, -- Cost per patient
cost_pp_pa TEXT, -- Cost per patient per annum (formatted string)
-- Visualization
colour REAL NOT NULL DEFAULT 0.0, -- Color value (proportion of parent)
-- Date ranges (for this node)
first_seen TEXT, -- First intervention date (ISO format)
last_seen TEXT, -- Last intervention date (ISO format)
first_seen_parent TEXT, -- Earliest date in parent group
last_seen_parent TEXT, -- Latest date in parent group
-- Treatment statistics
average_spacing TEXT, -- Formatted treatment duration string
average_administered TEXT, -- JSON array of average doses per drug
avg_days REAL, -- Average treatment duration in days
-- Denormalized filter columns (for efficient WHERE clause filtering)
trust_name TEXT, -- Extracted trust name from ids
directory TEXT, -- Extracted directory from ids
drug_sequence TEXT, -- Pipe-separated drug sequence from pathway
-- Metadata
created_at TEXT DEFAULT CURRENT_TIMESTAMP,
data_refresh_id TEXT, -- Links to pathway_refresh_log
-- Unique per date filter + chart type + pathway
UNIQUE(date_filter_id, chart_type, ids),
FOREIGN KEY (date_filter_id) REFERENCES pathway_date_filters(id)
);
-- Indexes for efficient filtering
-- Primary filter: select by date_filter_id
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_date_filter ON pathway_nodes(date_filter_id);
-- Chart type filter: for switching between directory and indication views
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_chart_type ON pathway_nodes(date_filter_id, chart_type);
-- Level filter: often used with date_filter_id
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_level ON pathway_nodes(date_filter_id, level);
-- Trust filter: for Trust dropdown filtering
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_trust ON pathway_nodes(date_filter_id, trust_name);
-- Directory filter: for Directory dropdown filtering
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_directory ON pathway_nodes(date_filter_id, directory);
-- Drug sequence filter: for drug filtering (uses LIKE '%DRUG%')
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_drug_seq ON pathway_nodes(drug_sequence);
-- Parents filter: for finding children of a node
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_parents ON pathway_nodes(date_filter_id, parents);
-- Composite index for common filter combination
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_filter_composite
ON pathway_nodes(date_filter_id, chart_type, trust_name, directory);
"""
# DDL script for pathway_refresh_log: one row per refresh run (status,
# timings, counts, error text), with indexes on started_at and status.
# get_pathway_refresh_status() reads the newest row by started_at.
# Databases created before source_row_count existed are upgraded separately
# by migrate_refresh_log_source_row_count().
PATHWAY_REFRESH_LOG_SCHEMA = """
-- Metadata table for tracking refresh status
-- Tracks when pathway data was last refreshed from Snowflake
CREATE TABLE IF NOT EXISTS pathway_refresh_log (
id INTEGER PRIMARY KEY AUTOINCREMENT,
refresh_id TEXT NOT NULL, -- Unique identifier for this refresh run
started_at TEXT NOT NULL, -- ISO timestamp when refresh started
completed_at TEXT, -- ISO timestamp when refresh completed (NULL if still running)
status TEXT DEFAULT 'running', -- 'running', 'completed', 'failed'
record_count INTEGER, -- Total pathway_nodes records created
date_filter_counts TEXT, -- JSON: {"all_6mo": 1234, "all_12mo": 1567, ...}
error_message TEXT, -- Error details if status='failed'
snowflake_query_date_from TEXT, -- Start date of Snowflake query
snowflake_query_date_to TEXT, -- End date of Snowflake query
processing_duration_seconds REAL, -- How long the refresh took
source_row_count INTEGER, -- Number of Snowflake rows fetched
created_at TEXT DEFAULT CURRENT_TIMESTAMP
);
-- Index for finding latest refresh
CREATE INDEX IF NOT EXISTS idx_pathway_refresh_log_started ON pathway_refresh_log(started_at DESC);
-- Index for finding by status
CREATE INDEX IF NOT EXISTS idx_pathway_refresh_log_status ON pathway_refresh_log(status);
"""
# Combined pathway schema: the three pathway scripts concatenated into one
# script, executed by create_pathway_tables(). pathway_date_filters comes
# first, ahead of pathway_nodes, whose foreign key references it.
PATHWAY_TABLES_SCHEMA = f"""
-- Pathway Data Architecture Tables
-- Pre-computed pathway data for fast Reflex filtering
{PATHWAY_DATE_FILTERS_SCHEMA}
{PATHWAY_NODES_SCHEMA}
{PATHWAY_REFRESH_LOG_SCHEMA}
"""
# =============================================================================
# Combined Schemas
# =============================================================================
# Combined reference schema: the five ref_* scripts concatenated into one
# script, executed by create_reference_tables().
REFERENCE_TABLES_SCHEMA = f"""
-- Reference Tables Schema
-- Contains lookup data migrated from CSV files
{REF_DRUG_NAMES_SCHEMA}
{REF_ORGANIZATIONS_SCHEMA}
{REF_DIRECTORIES_SCHEMA}
{REF_DRUG_DIRECTORY_MAP_SCHEMA}
{REF_DRUG_INDICATION_CLUSTERS_SCHEMA}
"""
# Complete schema: reference tables followed by pathway tables, as a single
# script executed by create_all_tables().
ALL_TABLES_SCHEMA = f"""
-- Complete Database Schema
-- Reference tables + Pathway tables
{REFERENCE_TABLES_SCHEMA}
{PATHWAY_TABLES_SCHEMA}
"""
# =============================================================================
# Schema Helper Functions
# =============================================================================
def create_reference_tables(conn: sqlite3.Connection) -> None:
    """
    Create the reference (``ref_*``) tables and their indexes.

    Runs the combined ``REFERENCE_TABLES_SCHEMA`` script in one call and logs
    before and after.

    Args:
        conn: SQLite database connection the DDL script is executed on.
    """
    logger.info("Creating reference tables...")
    ddl_script = REFERENCE_TABLES_SCHEMA
    conn.executescript(ddl_script)
    logger.info("Reference tables created successfully")
def drop_reference_tables(conn: sqlite3.Connection) -> None:
    """
    Remove every reference table (and therefore all reference data).

    Each statement is ``DROP TABLE IF EXISTS``, so tables that are already
    absent are skipped silently.

    Args:
        conn: SQLite database connection.

    Warning:
        All reference data is deleted. Use with caution.
    """
    drop_script = """
DROP TABLE IF EXISTS ref_drug_names;
DROP TABLE IF EXISTS ref_organizations;
DROP TABLE IF EXISTS ref_directories;
DROP TABLE IF EXISTS ref_drug_directory_map;
DROP TABLE IF EXISTS ref_drug_indication_clusters;
"""
    logger.warning("Dropping reference tables...")
    conn.executescript(drop_script)
    logger.info("Reference tables dropped")
def get_reference_table_counts(conn: sqlite3.Connection) -> dict[str, int]:
    """
    Count the rows in each reference table.

    A table that cannot be queried (``sqlite3.OperationalError``, e.g. because
    it has not been created yet) is reported with a count of 0.

    Args:
        conn: SQLite database connection.

    Returns:
        Dictionary mapping table name to row count, in a fixed table order.
    """
    reference_tables = (
        "ref_drug_names",
        "ref_organizations",
        "ref_directories",
        "ref_drug_directory_map",
        "ref_drug_indication_clusters",
    )

    def _row_count(table_name: str) -> int:
        # Table names come from the fixed tuple above, never from user input.
        try:
            first_row = conn.execute(f"SELECT COUNT(*) FROM {table_name}").fetchone()
        except sqlite3.OperationalError:
            return 0
        return first_row[0] if first_row else 0

    return {table_name: _row_count(table_name) for table_name in reference_tables}
def verify_reference_tables_exist(conn: sqlite3.Connection) -> list[str]:
    """
    Report which reference tables are absent from the database.

    Each expected table name is looked up in ``sqlite_master``.

    Args:
        conn: SQLite database connection.

    Returns:
        Names of the missing tables, in declaration order. An empty list means
        every reference table exists.
    """
    expected_tables = (
        "ref_drug_names",
        "ref_organizations",
        "ref_directories",
        "ref_drug_directory_map",
        "ref_drug_indication_clusters",
    )
    lookup_sql = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
    return [
        table_name
        for table_name in expected_tables
        if conn.execute(lookup_sql, (table_name,)).fetchone() is None
    ]
# =============================================================================
# Pathway Table Helper Functions
# =============================================================================
def create_pathway_tables(conn: sqlite3.Connection) -> None:
    """
    Create the pathway data architecture tables.

    Runs the combined ``PATHWAY_TABLES_SCHEMA`` script, which covers:
    - pathway_date_filters: the 6 pre-defined date filter combinations
    - pathway_nodes: pre-computed pathway hierarchy data
    - pathway_refresh_log: refresh tracking metadata

    Args:
        conn: SQLite database connection the DDL script is executed on.
    """
    logger.info("Creating pathway tables...")
    ddl_script = PATHWAY_TABLES_SCHEMA
    conn.executescript(ddl_script)
    logger.info("Pathway tables created successfully")
def drop_pathway_tables(conn: sqlite3.Connection) -> None:
    """
    Remove the pathway tables (and therefore all pre-computed pathway data).

    Each statement is ``DROP TABLE IF EXISTS``, so absent tables are skipped.
    pathway_nodes is dropped before pathway_date_filters, which it references.

    Args:
        conn: SQLite database connection.

    Warning:
        All pre-computed pathway data is deleted.
    """
    drop_script = """
DROP TABLE IF EXISTS pathway_nodes;
DROP TABLE IF EXISTS pathway_refresh_log;
DROP TABLE IF EXISTS pathway_date_filters;
"""
    logger.warning("Dropping pathway tables...")
    conn.executescript(drop_script)
    logger.info("Pathway tables dropped")
def get_pathway_table_counts(conn: sqlite3.Connection) -> dict[str, int]:
    """
    Count the rows in each pathway table.

    A table that cannot be queried (``sqlite3.OperationalError``, e.g. because
    it does not exist yet) is reported with a count of 0.

    Args:
        conn: SQLite database connection.

    Returns:
        Dictionary mapping table name to row count.
    """
    row_counts: dict[str, int] = {}
    for table_name in ("pathway_date_filters", "pathway_nodes", "pathway_refresh_log"):
        try:
            first_row = conn.execute(f"SELECT COUNT(*) FROM {table_name}").fetchone()
        except sqlite3.OperationalError:
            # Table doesn't exist yet
            row_counts[table_name] = 0
        else:
            row_counts[table_name] = first_row[0] if first_row else 0
    return row_counts
def verify_pathway_tables_exist(conn: sqlite3.Connection) -> list[str]:
    """
    Report which pathway tables are absent from the database.

    Each expected table name is looked up in ``sqlite_master``.

    Args:
        conn: SQLite database connection.

    Returns:
        Names of the missing tables. An empty list means all pathway tables exist.
    """
    lookup_sql = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
    expected_tables = ("pathway_date_filters", "pathway_nodes", "pathway_refresh_log")
    return [
        table_name
        for table_name in expected_tables
        if conn.execute(lookup_sql, (table_name,)).fetchone() is None
    ]
def clear_pathway_nodes(conn: sqlite3.Connection, date_filter_id: str | None = None) -> int:
"""
Clear pathway nodes, optionally for a specific date filter.
Args:
conn: SQLite database connection.
date_filter_id: If provided, only clear nodes for this date filter.
If None, clear all pathway nodes.
Returns:
Number of rows deleted.
"""
if date_filter_id:
cursor = conn.execute(
"DELETE FROM pathway_nodes WHERE date_filter_id = ?",
(date_filter_id,)
)
else:
cursor = conn.execute("DELETE FROM pathway_nodes")
deleted_count = cursor.rowcount
conn.commit()
logger.info(f"Cleared {deleted_count} pathway nodes")
return deleted_count
def get_pathway_refresh_status(conn: sqlite3.Connection) -> dict | None:
"""
Get the status of the most recent pathway refresh.
Args:
conn: SQLite database connection.
Returns:
Dictionary with refresh status, or None if no refresh has been done.
"""
try:
cursor = conn.execute("""
SELECT refresh_id, started_at, completed_at, status, record_count,
date_filter_counts, error_message, processing_duration_seconds
FROM pathway_refresh_log
ORDER BY started_at DESC
LIMIT 1
""")
row = cursor.fetchone()
if row:
return {
"refresh_id": row[0],
"started_at": row[1],
"completed_at": row[2],
"status": row[3],
"record_count": row[4],
"date_filter_counts": row[5],
"error_message": row[6],
"processing_duration_seconds": row[7],
}
return None
except sqlite3.OperationalError:
# Table doesn't exist yet
return None
def migrate_pathway_nodes_chart_type(conn: sqlite3.Connection) -> tuple[bool, str]:
    """
    Migrate pathway_nodes table to add chart_type column.

    This migration:
    1. Returns early if the table is missing or already has chart_type
    2. Adds chart_type with DEFAULT 'directory', so existing rows read back
       as 'directory' without an explicit UPDATE
    3. Creates the (date_filter_id, chart_type) index
    4. Rebuilds the composite filter index so that it includes chart_type

    The table-level UNIQUE constraint is NOT changed here: ALTER TABLE
    ADD COLUMN cannot modify constraints, so a migrated table keeps the
    constraint it was created with.
    NOTE(review): per the previous in-code comment the old constraint was
    (date_filter_id, ids); if so, rows that differ only by chart_type will
    still collide on migrated databases until the table is rebuilt — confirm.

    Args:
        conn: SQLite database connection.

    Returns:
        Tuple of (success: bool, message: str). Failures are logged and
        reported through the tuple rather than raised.
    """
    # Check if table exists
    cursor = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='pathway_nodes'"
    )
    if cursor.fetchone() is None:
        return True, "pathway_nodes table does not exist yet (will be created with chart_type column)"

    # Check if chart_type column already exists (column name is field 1 of
    # each PRAGMA table_info row)
    cursor = conn.execute("PRAGMA table_info(pathway_nodes)")
    columns = [row[1] for row in cursor.fetchall()]
    if "chart_type" in columns:
        return True, "chart_type column already exists in pathway_nodes"

    logger.info("Adding chart_type column to pathway_nodes table...")
    try:
        # Add column with default value
        conn.execute("""
ALTER TABLE pathway_nodes
ADD COLUMN chart_type TEXT NOT NULL DEFAULT 'directory'
""")
        # Create index for efficient filtering by chart type
        conn.execute("""
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_chart_type
ON pathway_nodes(date_filter_id, chart_type)
""")
        # The composite index definition changed, so drop and recreate it.
        # IF EXISTS tolerates a missing index without also hiding unrelated
        # operational errors (e.g. a locked database).
        conn.execute("DROP INDEX IF EXISTS idx_pathway_nodes_filter_composite")
        conn.execute("""
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_filter_composite
ON pathway_nodes(date_filter_id, chart_type, trust_name, directory)
""")
        conn.commit()
        logger.info("chart_type column added successfully")
        # Count the rows that now carry the default chart_type
        cursor = conn.execute("SELECT COUNT(*) FROM pathway_nodes")
        row_count = cursor.fetchone()[0]
        return True, f"Added chart_type column, {row_count} existing rows set to 'directory'"
    except Exception as e:
        # Undo any statements of this migration that are still uncommitted,
        # so a failed run does not leave pending work on the connection.
        try:
            conn.rollback()
        except sqlite3.Error as rollback_error:
            logger.warning(f"Rollback after failed chart_type migration also failed: {rollback_error}")
        logger.error(f"Failed to add chart_type column: {e}")
        return False, f"Migration failed: {e}"
def migrate_refresh_log_source_row_count(conn: sqlite3.Connection) -> tuple[bool, str]:
    """Add source_row_count column to pathway_refresh_log if it doesn't exist.

    This column stores the Snowflake row count for display in the UI footer.

    If pathway_refresh_log has not been created yet there is nothing to
    migrate (the current schema already declares the column), so that case is
    reported as success instead of attempting an ALTER TABLE that must fail.

    Args:
        conn: SQLite database connection.

    Returns:
        Tuple of (success: bool, message: str). Failures are logged and
        reported through the tuple rather than raised.
    """
    cursor = conn.execute("PRAGMA table_info(pathway_refresh_log)")
    columns = [row[1] for row in cursor.fetchall()]
    if not columns:
        # PRAGMA table_info yields no rows when the table does not exist.
        return True, "pathway_refresh_log table does not exist yet (will be created with source_row_count column)"
    if "source_row_count" in columns:
        return True, "source_row_count column already exists"

    logger.info("Adding source_row_count column to pathway_refresh_log...")
    try:
        conn.execute("""
ALTER TABLE pathway_refresh_log
ADD COLUMN source_row_count INTEGER
""")
        conn.commit()
        return True, "Added source_row_count column"
    except Exception as e:
        logger.error(f"Failed to add source_row_count column: {e}")
        return False, f"Migration failed: {e}"
# =============================================================================
# Combined Helper Functions
# =============================================================================
def create_all_tables(conn: sqlite3.Connection) -> None:
    """
    Create every table (reference + pathway) in the database.

    Runs the combined ``ALL_TABLES_SCHEMA`` script in one call and logs
    before and after.

    Args:
        conn: SQLite database connection the DDL script is executed on.
    """
    logger.info("Creating all database tables...")
    ddl_script = ALL_TABLES_SCHEMA
    conn.executescript(ddl_script)
    logger.info("All tables created successfully")
def drop_all_tables(conn: sqlite3.Connection) -> None:
    """
    Drop every table managed by this module.

    Pathway tables are dropped first, then the reference tables.

    Args:
        conn: SQLite database connection.

    Warning:
        All data is deleted. Use with extreme caution.
    """
    logger.warning("Dropping all tables...")
    for drop_group in (drop_pathway_tables, drop_reference_tables):
        drop_group(conn)
    logger.info("All tables dropped")
def get_all_table_counts(conn: sqlite3.Connection) -> dict[str, int]:
    """
    Count the rows in every table managed by this module.

    Merges the reference table counts with the pathway table counts.

    Args:
        conn: SQLite database connection.

    Returns:
        Dictionary mapping table name to row count, reference tables first.
    """
    return {
        **get_reference_table_counts(conn),
        **get_pathway_table_counts(conn),
    }
def verify_all_tables_exist(conn: sqlite3.Connection) -> list[str]:
    """
    Report which of the module's tables are absent from the database.

    Args:
        conn: SQLite database connection.

    Returns:
        Missing reference tables followed by missing pathway tables.
        An empty list means all tables exist.
    """
    return [
        *verify_reference_tables_exist(conn),
        *verify_pathway_tables_exist(conn),
    ]
+797
View File
@@ -0,0 +1,797 @@
"""
Snowflake connector module for NHS Patient Pathway Analysis.
Provides connection handling with SSO browser authentication for NHS environments.
Uses the externalbrowser authenticator which opens a browser window for NHS identity
management authentication.
Usage:
from data_processing.snowflake_connector import SnowflakeConnector, get_connector
# Using context manager (recommended)
with get_connector() as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM table LIMIT 10")
results = cursor.fetchall()
# Manual connection management
connector = SnowflakeConnector()
try:
conn = connector.connect()
cursor = conn.cursor()
# ... use cursor ...
finally:
connector.close()
"""
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import date, datetime
from pathlib import Path
from typing import Any, Generator, Optional, TYPE_CHECKING
import time
# Snowflake connector is an optional dependency
SNOWFLAKE_AVAILABLE = False
try:
import snowflake.connector
from snowflake.connector import SnowflakeConnection
from snowflake.connector.cursor import SnowflakeCursor
SNOWFLAKE_AVAILABLE = True
except ImportError:
snowflake = None # type: ignore[assignment]
# Type hints for when snowflake is not available
if TYPE_CHECKING:
from snowflake.connector import SnowflakeConnection
from snowflake.connector.cursor import SnowflakeCursor
from config import get_snowflake_config, SnowflakeConfig
from core.logging_config import get_logger
logger = get_logger(__name__)
class SnowflakeConnectionError(Exception):
    """Signals that a connection to Snowflake could not be established."""
class SnowflakeNotConfiguredError(Exception):
    """Signals that no Snowflake account has been configured."""
class SnowflakeNotAvailableError(Exception):
    """Signals that the snowflake-connector-python package is not installed."""
@dataclass
class ConnectionInfo:
    """Information about the current connection state.

    SnowflakeConnector replaces this object on every connect/close and
    updates the query fields each time a cursor context completes.
    """
    connected: bool = False  # True once connect() has succeeded
    account: str = ""  # Copied from the connection config on connect
    warehouse: str = ""  # Copied from the connection config on connect
    database: str = ""  # Copied from the connection config on connect
    schema: str = ""  # Copied from the connection config on connect
    user: str = ""  # Result of SELECT CURRENT_USER(), "" if unavailable
    role: str = ""  # Result of SELECT CURRENT_ROLE(), "" if unavailable
    connected_at: Optional[datetime] = None  # Local time when connect() succeeded
    last_query_at: Optional[datetime] = None  # Local time the last cursor context finished
    query_count: int = 0  # Number of cursor contexts that finished without error
class SnowflakeConnector:
"""
Manages Snowflake connections with SSO browser authentication.
This class provides connection management for NHS Snowflake access using
the externalbrowser authenticator which triggers NHS SSO login via browser.
Attributes:
config: SnowflakeConfig with connection settings
connection_info: ConnectionInfo tracking current state
Example:
connector = SnowflakeConnector()
with connector.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT CURRENT_USER()")
print(cursor.fetchone()[0])
"""
def __init__(self, config: Optional[SnowflakeConfig] = None):
    """
    Set up a connector that is not yet connected.

    Args:
        config: Optional SnowflakeConfig. When omitted (or falsy), the
            configuration is obtained from get_snowflake_config().
    """
    self._config = config if config else get_snowflake_config()
    # No connection is opened here; connect() does that on demand.
    self._connection: Optional[SnowflakeConnection] = None
    self._connection_info = ConnectionInfo()
@property
def config(self) -> SnowflakeConfig:
    """Return the Snowflake configuration this connector was initialised with."""
    return self._config
@property
def connection_info(self) -> ConnectionInfo:
    """Return the ConnectionInfo describing the current connection state.

    The object is replaced on connect() and close(), so callers should
    re-read this property rather than hold on to an old instance.
    """
    return self._connection_info
@property
def is_connected(self) -> bool:
"""Return True if currently connected to Snowflake."""
return self._connection is not None and not self._connection.is_closed()
def _check_availability(self) -> None:
    """Raise SnowflakeNotAvailableError unless the Snowflake package was imported."""
    if SNOWFLAKE_AVAILABLE:
        return
    raise SnowflakeNotAvailableError(
        "snowflake-connector-python is not installed. "
        "Install it with: pip install snowflake-connector-python"
    )
def _check_configured(self) -> None:
    """Raise SnowflakeNotConfiguredError unless the config reports it is configured."""
    if self._config.is_configured:
        return
    raise SnowflakeNotConfiguredError(
        "Snowflake account is not configured. "
        "Edit config/snowflake.toml and set connection.account"
    )
def connect(self) -> SnowflakeConnection:
    """
    Open a new Snowflake connection, replacing any existing one.

    With the externalbrowser authenticator a browser window opens for NHS SSO
    authentication; that popup is expected. After a successful connect the
    connection info is rebuilt, including the current user and role.

    Returns:
        Active SnowflakeConnection

    Raises:
        SnowflakeNotAvailableError: If snowflake-connector-python not installed
        SnowflakeNotConfiguredError: If account is not configured
        SnowflakeConnectionError: If connection fails (original error chained)
    """
    self._check_availability()
    self._check_configured()

    # Never keep two connections alive: drop the previous one first.
    if self._connection is not None:
        self.close()

    settings = self._config.connection
    timeouts = self._config.timeouts
    logger.info(f"Connecting to Snowflake account: {settings.account}")
    logger.info(f"Using warehouse: {settings.warehouse}, database: {settings.database}")
    logger.info(f"Authenticator: {settings.authenticator}")
    if settings.authenticator == "externalbrowser":
        logger.info("Browser window will open for NHS SSO authentication")

    started = time.time()
    try:
        params = {
            "account": settings.account,
            "warehouse": settings.warehouse,
            "database": settings.database,
            "schema": settings.schema,
            "authenticator": settings.authenticator,
            "login_timeout": timeouts.login_timeout,
            "network_timeout": timeouts.connection_timeout,
        }
        # user and role are optional: pass them only when they are set.
        for optional_key in ("user", "role"):
            optional_value = getattr(settings, optional_key)
            if optional_value:
                params[optional_key] = optional_value

        self._connection = snowflake.connector.connect(**params)
        elapsed = time.time() - started
        logger.info(f"Connected to Snowflake successfully in {elapsed:.1f}s")

        self._connection_info = ConnectionInfo(
            connected=True,
            account=settings.account,
            warehouse=settings.warehouse,
            database=settings.database,
            schema=settings.schema,
            user=self._get_current_user(),
            role=self._get_current_role(),
            connected_at=datetime.now(),
            query_count=0,
        )
        return self._connection
    except Exception as e:
        elapsed = time.time() - started
        logger.error(f"Failed to connect to Snowflake after {elapsed:.1f}s: {e}")
        self._connection_info = ConnectionInfo(connected=False)
        raise SnowflakeConnectionError(f"Failed to connect to Snowflake: {e}") from e
def close(self) -> None:
"""Close the Snowflake connection if open."""
if self._connection is not None:
try:
self._connection.close()
logger.info("Snowflake connection closed")
except Exception as e:
logger.warning(f"Error closing Snowflake connection: {e}")
finally:
self._connection = None
self._connection_info = ConnectionInfo(connected=False)
def _get_current_user(self) -> str:
"""Get the current authenticated user."""
if self._connection is None:
return ""
try:
cursor = self._connection.cursor()
cursor.execute("SELECT CURRENT_USER()")
result = cursor.fetchone()
return result[0] if result else ""
except Exception:
return ""
def _get_current_role(self) -> str:
"""Get the current active role."""
if self._connection is None:
return ""
try:
cursor = self._connection.cursor()
cursor.execute("SELECT CURRENT_ROLE()")
result = cursor.fetchone()
return result[0] if result else ""
except Exception:
return ""
@contextmanager
def get_connection(self) -> Generator[SnowflakeConnection, None, None]:
    """
    Context manager that yields the connector's connection.

    Connects first if not already connected. The connection is NOT closed
    when the ``with`` block exits; it stays open for reuse until close()
    is called.

    Yields:
        Active SnowflakeConnection

    Example:
        connector = SnowflakeConnector()
        with connector.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT 1")
    """
    if not self.is_connected:
        self.connect()
    assert self._connection is not None, "Connection should be established"
    # Deliberately no teardown after the yield: keep connection open for reuse.
    yield self._connection
@contextmanager
def get_cursor(
    self,
    dict_cursor: bool = False
) -> Generator[SnowflakeCursor, None, None]:
    """
    Context manager that yields a cursor and closes it afterwards.

    Connects first if not already connected. When the ``with`` body finishes
    without raising, the connection info's last_query_at and query_count are
    updated; the cursor is closed in every case.

    Args:
        dict_cursor: If True, the cursor is created with DictCursor so rows
            are dict-like

    Yields:
        SnowflakeCursor for executing queries

    Example:
        connector = SnowflakeConnector()
        with connector.get_cursor() as cursor:
            cursor.execute("SELECT * FROM table LIMIT 10")
            for row in cursor:
                print(row)
    """
    if not self.is_connected:
        self.connect()
    assert self._connection is not None, "Connection should be established"

    cursor: Any = None
    try:
        cursor = (
            self._connection.cursor(snowflake.connector.DictCursor)  # type: ignore[union-attr]
            if dict_cursor
            else self._connection.cursor()
        )
        yield cursor  # type: ignore[misc]
        # Reached only when the caller's block completed without an exception.
        info = self._connection_info
        info.last_query_at = datetime.now()
        info.query_count += 1
    finally:
        if cursor is not None:
            cursor.close()
def execute(
    self,
    query: str,
    params: Optional[tuple] = None,
    timeout: Optional[int] = None
) -> list[tuple]:
    """
    Run a query and return every result row.

    Connects first if needed. When the effective timeout is positive it is
    applied through ``ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS`` before
    the query runs.

    Args:
        query: SQL query to execute
        params: Optional query parameters for parameterized queries
        timeout: Optional query timeout in seconds; a falsy value (None or 0)
            falls back to the configured query_timeout

    Returns:
        List of result rows as tuples

    Raises:
        SnowflakeConnectionError: If a connection cannot be established
        Various snowflake errors for query issues
    """
    if not self.is_connected:
        self.connect()

    limit_seconds = timeout or self._config.timeouts.query_timeout
    with self.get_cursor() as cur:
        logger.info(f"Executing query (timeout={limit_seconds}s)")
        logger.debug(f"Query: {query[:200]}...")
        if limit_seconds > 0:
            cur.execute(f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {limit_seconds}")

        began = time.time()
        cur.execute(query, params)
        rows = cur.fetchall()
        elapsed = time.time() - began
        logger.info(f"Query returned {len(rows)} rows in {elapsed:.2f}s")
        return rows
def execute_dict(
    self,
    query: str,
    params: Optional[tuple] = None,
    timeout: Optional[int] = None
) -> list[dict]:
    """
    Run a query and return every result row as a dictionary.

    Same flow as execute(), but the cursor is requested with
    ``dict_cursor=True``.

    Args:
        query: SQL query to execute
        params: Optional query parameters
        timeout: Optional query timeout in seconds; a falsy value (None or 0)
            falls back to the configured query_timeout

    Returns:
        List of result rows as dictionaries
    """
    if not self.is_connected:
        self.connect()

    limit_seconds = timeout or self._config.timeouts.query_timeout
    with self.get_cursor(dict_cursor=True) as cur:
        logger.info(f"Executing query (timeout={limit_seconds}s)")
        logger.debug(f"Query: {query[:200]}...")
        if limit_seconds > 0:
            cur.execute(f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {limit_seconds}")

        began = time.time()
        cur.execute(query, params)
        rows = cur.fetchall()
        elapsed = time.time() - began
        logger.info(f"Query returned {len(rows)} rows in {elapsed:.2f}s")
        return rows  # type: ignore[return-value]
def execute_chunked(
    self,
    query: str,
    params: Optional[tuple] = None,
    chunk_size: Optional[int] = None,
    timeout: Optional[int] = None,
    max_rows: Optional[int] = None,
) -> Generator[list[tuple], None, None]:
    """
    Execute a query and yield results in chunks for memory efficiency.

    This method is useful for large result sets that would exceed memory
    if loaded all at once. Results are yielded as chunks of rows. Because
    this is a generator, nothing (including connecting) happens until the
    caller starts iterating.

    Args:
        query: SQL query to execute
        params: Optional query parameters for parameterized queries
        chunk_size: Number of rows per chunk (default from config)
        timeout: Optional query timeout in seconds (overrides config).
            None uses the configured query_timeout; a value <= 0 skips the
            ALTER SESSION statement, leaving the session timeout untouched.
        max_rows: Maximum total rows to return (default from config, 0 for no limit)

    Yields:
        List of result rows as tuples for each chunk

    Example:
        for chunk in connector.execute_chunked("SELECT * FROM large_table"):
            process_chunk(chunk)
    """
    if not self.is_connected:
        self.connect()
    # `is not None` (not `or`) so that an explicit timeout=0 is honoured
    # instead of being silently replaced by the configured default.
    effective_timeout = timeout if timeout is not None else self._config.timeouts.query_timeout
    # `or` is intentional here: a chunk size of 0 is never usable.
    effective_chunk_size = chunk_size or self._config.query.chunk_size
    effective_max_rows = max_rows if max_rows is not None else self._config.query.max_rows
    with self.get_cursor() as cursor:
        logger.info(f"Executing chunked query (chunk_size={effective_chunk_size}, timeout={effective_timeout}s)")
        logger.debug(f"Query: {query[:200]}...")
        if effective_timeout > 0:
            # The value is interpolated into SQL, so force it to an integer.
            cursor.execute(
                f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {int(effective_timeout)}"
            )
        start_time = time.time()
        cursor.execute(query, params)
        total_rows = 0
        chunk_num = 0
        while True:
            # Determine how many rows to fetch this chunk; a positive
            # max_rows caps the final chunk so the total never exceeds it.
            fetch_size = effective_chunk_size
            if effective_max_rows > 0:
                remaining = effective_max_rows - total_rows
                if remaining <= 0:
                    break
                fetch_size = min(effective_chunk_size, remaining)
            chunk = cursor.fetchmany(fetch_size)
            if not chunk:
                break
            chunk_num += 1
            total_rows += len(chunk)
            logger.debug(f"Chunk {chunk_num}: {len(chunk)} rows (total: {total_rows})")
            yield chunk
        elapsed = time.time() - start_time
        logger.info(f"Chunked query returned {total_rows} rows in {chunk_num} chunks ({elapsed:.2f}s)")
def execute_chunked_dict(
    self,
    query: str,
    params: Optional[tuple] = None,
    chunk_size: Optional[int] = None,
    timeout: Optional[int] = None,
    max_rows: Optional[int] = None,
) -> Generator[list[dict], None, None]:
    """
    Execute a query and yield dict results in chunks for memory efficiency.

    Same as execute_chunked but requests a dict cursor, so rows are
    dictionaries. Because this is a generator, nothing (including
    connecting) happens until the caller starts iterating.

    Args:
        query: SQL query to execute
        params: Optional query parameters
        chunk_size: Number of rows per chunk (default from config)
        timeout: Optional query timeout in seconds. None uses the configured
            query_timeout; a value <= 0 skips the ALTER SESSION statement,
            leaving the session timeout untouched.
        max_rows: Maximum total rows to return (default from config, 0 for no limit)

    Yields:
        List of result rows as dictionaries for each chunk
    """
    if not self.is_connected:
        self.connect()
    # `is not None` (not `or`) so that an explicit timeout=0 is honoured
    # instead of being silently replaced by the configured default.
    effective_timeout = timeout if timeout is not None else self._config.timeouts.query_timeout
    # `or` is intentional here: a chunk size of 0 is never usable.
    effective_chunk_size = chunk_size or self._config.query.chunk_size
    effective_max_rows = max_rows if max_rows is not None else self._config.query.max_rows
    with self.get_cursor(dict_cursor=True) as cursor:
        logger.info(f"Executing chunked dict query (chunk_size={effective_chunk_size}, timeout={effective_timeout}s)")
        logger.debug(f"Query: {query[:200]}...")
        if effective_timeout > 0:
            # The value is interpolated into SQL, so force it to an integer.
            cursor.execute(
                f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {int(effective_timeout)}"
            )
        start_time = time.time()
        cursor.execute(query, params)
        total_rows = 0
        chunk_num = 0
        while True:
            # Determine how many rows to fetch this chunk; a positive
            # max_rows caps the final chunk so the total never exceeds it.
            fetch_size = effective_chunk_size
            if effective_max_rows > 0:
                remaining = effective_max_rows - total_rows
                if remaining <= 0:
                    break
                fetch_size = min(effective_chunk_size, remaining)
            chunk = cursor.fetchmany(fetch_size)
            if not chunk:
                break
            chunk_num += 1
            total_rows += len(chunk)
            logger.debug(f"Chunk {chunk_num}: {len(chunk)} rows (total: {total_rows})")
            yield chunk  # type: ignore[misc]
        elapsed = time.time() - start_time
        logger.info(f"Chunked dict query returned {total_rows} rows in {chunk_num} chunks ({elapsed:.2f}s)")
def execute_with_row_limit(
    self,
    query: str,
    params: Optional[tuple] = None,
    max_rows: Optional[int] = None,
    timeout: Optional[int] = None
) -> tuple[list[dict], bool]:
    """
    Execute a query with a row limit and indicate if more rows were available.

    This is useful for pagination or previewing large result sets.

    Args:
        query: SQL query to execute
        params: Optional query parameters
        max_rows: Maximum rows to return (default from config). A value
            <= 0 means "no limit", matching the chunked methods: every row
            is returned and has_more is False.
        timeout: Optional query timeout in seconds. None uses the configured
            query_timeout; a value <= 0 skips the ALTER SESSION statement,
            leaving the session timeout untouched.

    Returns:
        Tuple of (results list, has_more bool)
        - results: List of result rows as dictionaries (up to max_rows)
        - has_more: True if there were more rows than max_rows
    """
    if not self.is_connected:
        self.connect()
    # `is not None` (not `or`) so that an explicit timeout=0 is honoured
    # instead of being silently replaced by the configured default.
    effective_timeout = timeout if timeout is not None else self._config.timeouts.query_timeout
    effective_max_rows = max_rows if max_rows is not None else self._config.query.max_rows
    with self.get_cursor(dict_cursor=True) as cursor:
        logger.info(f"Executing query with limit (max_rows={effective_max_rows}, timeout={effective_timeout}s)")
        logger.debug(f"Query: {query[:200]}...")
        if effective_timeout > 0:
            # The value is interpolated into SQL, so force it to an integer.
            cursor.execute(
                f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {int(effective_timeout)}"
            )
        start_time = time.time()
        cursor.execute(query, params)
        if effective_max_rows > 0:
            # Fetch one more than max to detect if there are more rows
            results = cursor.fetchmany(effective_max_rows + 1)
            has_more = len(results) > effective_max_rows
            if has_more:
                results = results[:effective_max_rows]
        else:
            # No limit configured: without this branch a limit of 0 would
            # return an empty list while claiming more rows exist.
            results = cursor.fetchall()
            has_more = False
        elapsed = time.time() - start_time
        logger.info(f"Query returned {len(results)} rows (has_more={has_more}) in {elapsed:.2f}s")
        return results, has_more  # type: ignore[return-value]
def fetch_activity_data(
    self,
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    provider_codes: Optional[list[str]] = None,
    max_rows: Optional[int] = None,
    timeout: Optional[int] = None,
) -> list[dict]:
    """
    Fetch high-cost drug activity data from Snowflake.

    Queries DATA_HUB.CDM."Acute__Conmon__PatientLevelDrugs" and aliases the
    source columns to the spaced names listed below. All filter values are
    passed as bind parameters; only the fixed table name and the generated
    "%s" placeholders are interpolated into the SQL text.

    Args:
        start_date: Optional start date for filtering (inclusive)
        end_date: Optional end date for filtering (inclusive)
        provider_codes: Optional list of provider codes to filter by
        max_rows: Maximum rows to return (default from config)
        timeout: Query timeout in seconds (default from config)

    Returns:
        List of dictionaries (all chunks from execute_chunked_dict,
        concatenated) with the aliased columns:
        - PseudoNHSNoLinked
        - Provider Code
        - PersonKey
        - Drug Name
        - Intervention Date
        - Price Actual
        - OrganisationName
        - Treatment Function Code / Treatment Function Desc
        - Additional Detail 1-5 / Additional Description 1-5

    Raises:
        Any error raised by connect() or execute_chunked_dict() propagates
        unchanged.
    """
    if not self.is_connected:
        self.connect()
    # Build the query
    table_name = 'DATA_HUB.CDM."Acute__Conmon__PatientLevelDrugs"'
    query = f'''
        SELECT
            "PseudoNHSNoLinked",
            "ProviderCode" AS "Provider Code",
            "LocalPatientID" AS "PersonKey",
            "DrugName" AS "Drug Name",
            "InterventionDate" AS "Intervention Date",
            "PriceActual" AS "Price Actual",
            "ProviderName" AS "OrganisationName",
            "TreatmentFunctionCode" AS "Treatment Function Code",
            "TreatmentFunctionDesc" AS "Treatment Function Desc",
            "AdditionalDetail1" AS "Additional Detail 1",
            "AdditionalDescription1" AS "Additional Description 1",
            "AdditionalDetail2" AS "Additional Detail 2",
            "AdditionalDescription2" AS "Additional Description 2",
            "AdditionalDetail3" AS "Additional Detail 3",
            "AdditionalDescription3" AS "Additional Description 3",
            "AdditionalDetail4" AS "Additional Detail 4",
            "AdditionalDescription4" AS "Additional Description 4",
            "AdditionalDetail5" AS "Additional Detail 5",
            "AdditionalDescription5" AS "Additional Description 5"
        FROM {table_name}
        WHERE 1=1
    '''
    params: list = []
    # Add date filters (the leading "WHERE 1=1" lets each be a plain AND)
    if start_date is not None:
        query += ' AND "InterventionDate" >= %s'
        params.append(start_date.isoformat())
    if end_date is not None:
        query += ' AND "InterventionDate" <= %s'
        params.append(end_date.isoformat())
    # Add provider filter: one placeholder per code
    if provider_codes:
        placeholders = ", ".join(["%s"] * len(provider_codes))
        query += f' AND "ProviderCode" IN ({placeholders})'
        params.extend(provider_codes)
    # Add ordering for consistent results
    query += ' ORDER BY "InterventionDate", "ProviderCode", "PseudoNHSNoLinked"'
    logger.info("Fetching activity data from Snowflake")
    if start_date is not None or end_date is not None:
        # Also logged when only an end date was supplied.
        logger.info(f" Date range: {start_date or 'earliest'} to {end_date or 'now'}")
    if provider_codes:
        logger.info(f" Providers: {provider_codes}")
    # Execute with chunked results for large datasets. timeout and max_rows
    # are forwarded untouched: execute_chunked_dict applies the configured
    # defaults for None, and pre-resolving the timeout here with `or` would
    # discard an explicit timeout=0.
    all_results: list[dict] = []
    for chunk in self.execute_chunked_dict(
        query,
        params=tuple(params) if params else None,
        timeout=timeout,
        max_rows=max_rows,
    ):
        all_results.extend(chunk)
        logger.debug(f"Fetched {len(all_results)} rows so far...")
    logger.info(f"Fetched {len(all_results)} activity records from Snowflake")
    return all_results
def test_connection(self) -> tuple[bool, str]:
    """
    Probe the Snowflake connection and report the outcome as a message.

    The availability check, the configuration check and the connection
    attempt run in that order; the first one that fails ends the probe.
    On success the connection opened by connect() is left as it is.

    Returns:
        Tuple of (success: bool, message: str)
    """
    # Each pre-flight check is guarded separately so that only its own
    # exception type is converted into a failure message.
    try:
        self._check_availability()
    except SnowflakeNotAvailableError as unavailable:
        return False, str(unavailable)
    try:
        self._check_configured()
    except SnowflakeNotConfiguredError as unconfigured:
        return False, str(unconfigured)
    # Anything that goes wrong while connecting or querying the identity is
    # reported rather than raised.
    try:
        self.connect()
        current_user = self._get_current_user()
        current_role = self._get_current_role()
        summary = f"Connected as {current_user} with role {current_role}"
    except Exception as failure:
        return False, f"Connection failed: {failure}"
    return True, summary
def __enter__(self) -> "SnowflakeConnector":
    """Connect on entry to a ``with`` block and return this connector."""
    self.connect()
    return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
    """Close the connector on leaving a ``with`` block.

    Returns None, so an exception raised inside the block is not suppressed.
    """
    self.close()
# Module-level singleton for convenience.
# Starts as None; get_connector() creates it on first use when called without
# a config, and reset_connector() closes it and sets it back to None.
_default_connector: Optional[SnowflakeConnector] = None
def get_connector(config: Optional[SnowflakeConfig] = None) -> SnowflakeConnector:
    """
    Return a Snowflake connector, sharing one default instance.

    Args:
        config: Optional configuration. When given, a brand-new connector is
            built from it on every call and is NOT stored as the default.
            When None, the shared default connector is returned, being
            created on the first such call.

    Returns:
        SnowflakeConnector instance
    """
    global _default_connector
    if config is None:
        # Lazily build the shared instance the first time it is requested.
        if _default_connector is None:
            _default_connector = SnowflakeConnector()
        return _default_connector
    # Custom config requested: hand back an independent connector.
    return SnowflakeConnector(config)
def reset_connector() -> None:
    """
    Reset the default connector (closes connection and clears singleton).

    Does nothing when no default connector exists. If close() raises, the
    error still propagates, but the singleton is cleared first so the next
    get_connector() call builds a fresh connector instead of reusing one
    that failed to close.
    """
    global _default_connector
    connector = _default_connector
    if connector is None:
        return
    try:
        connector.close()
    finally:
        _default_connector = None
def is_snowflake_available() -> bool:
    """Return the module-level SNOWFLAKE_AVAILABLE flag.

    The flag is presumably set where snowflake-connector-python is imported
    (outside this function); this helper only exposes it.
    """
    available = SNOWFLAKE_AVAILABLE
    return available
def is_snowflake_configured() -> bool:
    """Return the ``is_configured`` value of the loaded Snowflake config.

    Best-effort check: any exception raised while loading the config or
    reading the attribute is swallowed and reported as False.
    """
    configured = False
    try:
        configured = get_snowflake_config().is_configured
    except Exception:
        configured = False
    return configured
# Export public API.
# These are the names a star-import of this module exposes: the connector
# class, its exception types, ConnectionInfo, the singleton helpers and the
# availability/configuration probes.
__all__ = [
    "SnowflakeConnector",
    "SnowflakeConnectionError",
    "SnowflakeNotConfiguredError",
    "SnowflakeNotAvailableError",
    "ConnectionInfo",
    "get_connector",
    "reset_connector",
    "is_snowflake_available",
    "is_snowflake_configured",
    "SNOWFLAKE_AVAILABLE",
]
+331
View File
@@ -0,0 +1,331 @@
import numpy as np
import pandas as pd
import csv
import urllib.request
import io # Added for StringIO
import re # Added for regex escape and word boundaries
from typing import Optional
from core import PathConfig, default_paths
from core.logging_config import get_logger
logger = get_logger(__name__)
def drug_names(df, paths: Optional[PathConfig] = None):
    """
    Standardise the "Drug Name" column via the drugnames.csv lookup.

    The CSV is read as (activity name, generic name) pairs; both sides are
    upper-cased. Each "Drug Name" is upper-cased and replaced by its generic
    name; names without a lookup entry become NaN. "(LEFT EYE)" and
    "(RIGHT EYE)" markers are then removed and surrounding whitespace
    stripped.

    Args:
        df: DataFrame with a string "Drug Name" column. Modified in place.
        paths: Path configuration providing drugnames_csv (defaults to
            default_paths).

    Returns:
        The same DataFrame object, with "Drug Name" standardised.
    """
    if paths is None:
        paths = default_paths
    # Generate dictionary to convert drug names from activity data to generic standardisation
    lookup = {}
    with open(paths.drugnames_csv, 'r', newline='') as f:
        for row in csv.reader(f, delimiter=','):
            # Blank or truncated lines cannot form a (name, generic) pair;
            # skip them rather than failing on the unpack. Any columns past
            # the second are ignored.
            if len(row) < 2:
                continue
            lookup[row[0].upper()] = row[1].upper()
    # Map drug names with dictionary generated earlier
    standardised = df["Drug Name"].str.upper().map(lookup)
    # Remove (Left eye) or (Right eye) from Drug Name; the final strip()
    # clears the whitespace left behind at either end
    standardised = standardised.str.replace(r'\(LEFT EYE\)', '', regex=True)
    standardised = standardised.str.replace(r'\(RIGHT EYE\)', '', regex=True)
    df["Drug Name"] = standardised.str.strip()
    return df
def patient_id(df):
    """
    Add a "UPID" column: first 3 characters of "Provider Code" + "PersonKey".

    A row whose "Provider Code" or "PersonKey" is missing gets a missing
    UPID. Previously a missing PersonKey was stringified (e.g. "RX1nan"),
    which gave unrelated patients one shared ID.

    Args:
        df: DataFrame with "Provider Code" (string) and "PersonKey" columns.
            Modified in place.

    Returns:
        The same DataFrame object, with the "UPID" column added.
    """
    person_key = df["PersonKey"]
    # Stringify only real keys; where() puts NaN back for missing ones so the
    # concatenation below yields NaN instead of a literal "nan"/"None".
    key_text = person_key.astype(str).where(person_key.notna())
    df["UPID"] = df["Provider Code"].str[:3] + key_text
    return df
def compress_csv(filepath):
    """
    Write a bz2-compressed copy of a CSV file next to the original.

    The file is streamed byte-for-byte through bz2, so the decompressed
    content is identical to the source. The previous implementation parsed
    and re-serialised the file with pandas, which could rewrite values
    (e.g. leading zeros dropped by type inference).

    Args:
        filepath: Path of the CSV file to compress (str or os.PathLike).

    Returns:
        Path of the compressed file as a string: "<name>_bz2.csv" for an
        input "<name>.csv". For an input without a ".csv" suffix,
        "_bz2.csv" is appended, so the source file is never overwritten.
        Read it back with ``pd.read_csv(path, compression="bz2")``; the
        ".csv" suffix prevents automatic compression detection.
    """
    # Local imports: these modules are not imported at the top of this file.
    import bz2
    import os
    import shutil

    source_path = os.fspath(filepath)
    # Only the trailing extension is rewritten; a ".csv" appearing earlier in
    # the path (e.g. in a directory name) is left alone.
    stem, extension = os.path.splitext(source_path)
    if extension == ".csv":
        compressed_path = stem + "_bz2.csv"
    else:
        compressed_path = source_path + "_bz2.csv"
    with open(source_path, "rb") as source, bz2.open(compressed_path, "wb") as target:
        shutil.copyfileobj(source, target)
    return compressed_path
def department_identification(df, paths: Optional[PathConfig] = None):
    """
    Assign a "Directory" to every activity row and record how it was chosen.

    The first rule that yields a value wins; the rule used is stored in
    "Directory_Source":
        1. SINGLE_VALID_DIR - the drug has exactly one valid directory in
           drug_directory_list.csv.
        2. CALCULATED_MOST_FREQ - per (UPID, Drug Name), the most frequent
           candidate directory that is valid for the drug. The candidate for
           a row is the directory name found in "Additional Detail 1", else
           the one found in the service mapped from "Treatment Function Code".
        3. UPID_INFERENCE - the most frequent directory already assigned to
           other rows of the same UPID.
        4. UNDEFINED - Directory is set to "Undefined".

    Side effects on the frame: the additional detail/description columns are
    upper-cased and reduced to the directory name they contain (or NaN),
    missing ones are created as NaN, "Treatment Function Code" is consumed
    (it is not in the result), and rows with a missing or empty UPID are
    dropped. Rows still lacking a directory after rule 2 are written to
    paths.na_directory_rows_csv for diagnostics. The input frame is modified
    in place; use the returned frame, which is a new object whenever rule 2
    had candidates to evaluate.

    Args:
        df: Activity DataFrame with at least "UPID", "Drug Name" and
            "Treatment Function Code" columns.
        paths: Path configuration (defaults to default_paths).

    Returns:
        The processed DataFrame. If directory_list.csv is missing or
        unusable, an error is logged and df is returned unchanged (without
        "Directory"/"Directory_Source").
    """
    # --- Setup ---
    if paths is None:
        paths = default_paths
    # 1. Load directory_list.csv and prepare uppercase versions/pattern
    try:
        directory_df = pd.read_csv(paths.directory_list_csv)
        directory_list = directory_df["directory"].dropna().astype(str).tolist()
        if not directory_list:
            raise ValueError("directory_list.csv is empty or contains only NA values.")
        directory_list_upper = [d.upper() for d in directory_list]
        # Leading word boundary (\b) so a directory name is not matched in the
        # middle of a longer word; names are regex-escaped
        dir_pattern_upper = r'\b({})'.format('|'.join(map(re.escape, directory_list_upper)))
    except FileNotFoundError:
        logger.error(f"File not found: {paths.directory_list_csv}. Cannot extract directories.")
        return df
    except ValueError as e:
        logger.error(f"Error loading directory list: {e}")
        return df
    # Simpler pattern for Primary_Source (no word boundaries)
    dir_pattern_primary_simple = r'({})'.format('|'.join(map(re.escape, directory_list_upper)))

    # 2. Load treatment_function_codes.csv and prepare uppercase mapping
    treatment_codes = pd.read_csv(paths.treatment_function_codes_csv)
    mapping_treatment_codes = dict(treatment_codes[['Code', 'Service']].values)
    mapping_treatment_codes_upper = {k: str(v).upper() for k, v in mapping_treatment_codes.items()}

    # 3. Load drug_directory_list.csv: first column is the drug name, second
    #    column holds its valid directories separated by "|"
    drug_dir_df = pd.read_csv(paths.drug_directory_list_csv, skipinitialspace=True)
    drug_col = drug_dir_df.columns[0]
    dir_col = drug_dir_df.columns[1]
    drug_to_valid_dirs: dict[str, set[str]] = {}
    for raw_drug, raw_dirs in zip(drug_dir_df[drug_col], drug_dir_df[dir_col]):
        drug_name = str(raw_drug).strip().upper()
        dirs_str = "" if pd.isna(raw_dirs) else str(raw_dirs)
        dirs = {d.strip().upper() for d in dirs_str.split('|') if d.strip()}
        # A missing drug name stringifies to "nan"; skip it
        if drug_name and dirs and drug_name.lower() != 'nan':
            drug_to_valid_dirs[drug_name] = dirs

    # 4. Drugs with exactly one valid directory can be assigned immediately
    drug_to_single_dir = {
        drug: next(iter(dirs))
        for drug, dirs in drug_to_valid_dirs.items()
        if len(dirs) == 1
    }

    # --- Data Preprocessing ---
    additional_detail_columns = ["Additional Detail 1", "Additional Description 1", "Additional Detail 2", "Additional Description 2",
                                 "Additional Detail 3", "Additional Description 3", "Additional Detail 4", "Additional Description 4",
                                 "Additional Detail 5", "Additional Description 5", "NCDR Treatment Function Name", "Treatment Function Desc"]
    # 5. Convert detail columns to uppercase BEFORE extraction (text columns
    #    only: object dtype, or a dedicated string dtype)
    for ad in additional_detail_columns:
        if ad in df.columns and (
            pd.api.types.is_object_dtype(df[ad]) or pd.api.types.is_string_dtype(df[ad])
        ):
            df[ad] = df[ad].str.upper()

    # 6. Reduce each detail column to the directory name it contains
    for ad in additional_detail_columns:
        if ad not in df.columns:
            # An absent column is expected for some sources: create it as NaN
            # quietly instead of logging a KeyError as a processing error.
            df[ad] = np.nan
            continue
        try:
            # NOTE(review): recent pandas versions appear to report
            # is_string_dtype() as False for object columns that contain
            # non-string values such as NaN, which blanks the whole column
            # here - confirm this is intended.
            if pd.api.types.is_string_dtype(df[ad]):
                df[ad] = df[ad].str.extract(dir_pattern_upper, expand=False)
            else:
                df[ad] = np.nan  # Set non-string columns to NaN
        except AttributeError:  # Column does not support the .str accessor
            df[ad] = np.nan
        except Exception as e:  # Catch other potential errors during extract
            logger.error(f"Error processing column {ad}: {e}")
            df[ad] = np.nan

    # 7. Process Treatment Function Code: missing -> 0, force to int, then map
    #    to the upper-cased service name. Results are assigned back to the
    #    column; an inplace call on the selected column does not reliably
    #    update the frame under pandas Copy-on-Write.
    df["Treatment Function Code"] = df["Treatment Function Code"].replace(np.nan, 0)
    try:
        df["Treatment Function Code"] = df["Treatment Function Code"].astype(int)
    except ValueError:
        # Non-numeric values: coerce them to NaN, then treat as 0
        df["Treatment Function Code"] = pd.to_numeric(df["Treatment Function Code"], errors='coerce').fillna(0).astype(int)
    df["Treatment Function Code"] = df["Treatment Function Code"].map(mapping_treatment_codes_upper)
    df.rename(columns={'Treatment Function Code': 'Fallback_Source'}, inplace=True)
    # Apply replacements before combining (whole-frame, exact cell match)
    df.replace('MEDICAL OPHTHALMOLOGY', 'OPHTHALMOLOGY', inplace=True)

    # --- Single Directory Assignment ---
    # 8. Apply single directory override and start tracking the source
    df['Directory'] = df['Drug Name'].map(drug_to_single_dir)
    df['Directory_Source'] = pd.NA
    df.loc[df['Directory'].notna(), 'Directory_Source'] = 'SINGLE_VALID_DIR'

    # --- Prepare Fallback Logic ---
    # 9. Primary source is Additional Detail 1 (guaranteed to exist after step 6)
    df['Primary_Source'] = df['Additional Detail 1'].astype(pd.StringDtype()).str.upper()
    try:
        df['Extracted_Primary_Dir'] = df['Primary_Source'].str.extract(dir_pattern_primary_simple, expand=False, flags=re.IGNORECASE)
        df['Extracted_Fallback_Dir'] = df['Fallback_Source'].str.extract(dir_pattern_upper, expand=False, flags=re.IGNORECASE)
    except Exception as e:
        logger.error(f"Error during directory extraction: {e}")
        # Assign NA columns if extraction fails
        df['Extracted_Primary_Dir'] = pd.NA
        df['Extracted_Fallback_Dir'] = pd.NA
    # Strip potential whitespace from extracted directories
    df['Extracted_Primary_Dir'] = df['Extracted_Primary_Dir'].str.strip()
    df['Extracted_Fallback_Dir'] = df['Extracted_Fallback_Dir'].str.strip()

    # 10. Combine sources, prioritizing the primary one
    df['Primary_Directory'] = df['Extracted_Primary_Dir'].fillna(df['Extracted_Fallback_Dir'])
    # 11. Clean up intermediate columns
    df.drop(columns=['Primary_Source', 'Fallback_Source', 'Extracted_Primary_Dir', 'Extracted_Fallback_Dir'], inplace=True, errors='ignore')

    # --- Calculate Most Frequent Valid Directory ---
    # 12. Rows still without a Directory that do have a candidate
    df_to_process = df[df['Directory'].isnull()].dropna(subset=['Primary_Directory'])
    if not df_to_process.empty:
        # 13. Count candidates per (UPID, Drug Name); highest count first
        df_counts = (
            df_to_process
            .groupby(['UPID', 'Drug Name', 'Primary_Directory'], observed=True)
            .size()
            .reset_index(name='count')
        )
        df_counts = df_counts.sort_values(['UPID', 'Drug Name', 'count'], ascending=[True, True, False])
        # 14. Keep candidates that are valid for their drug; the first one
        #     left per (UPID, Drug Name) is the most frequent valid directory.
        #     Done with a mask rather than groupby.apply, which would have to
        #     read the grouping column inside each group.
        candidate_is_valid = np.array(
            [
                isinstance(candidate, str) and candidate in drug_to_valid_dirs.get(drug, ())
                for drug, candidate in zip(df_counts['Drug Name'], df_counts['Primary_Directory'])
            ],
            dtype=bool,
        )
        final_mapping = (
            df_counts.loc[candidate_is_valid, ['UPID', 'Drug Name', 'Primary_Directory']]
            .drop_duplicates(subset=['UPID', 'Drug Name'], keep='first')
            .rename(columns={'Primary_Directory': 'Calculated_Directory'})
        )
        final_mapping['Calculated_Directory'] = final_mapping['Calculated_Directory'].astype(object)
        # 15. Merge back; groups without a valid candidate get NaN
        df = pd.merge(df, final_mapping, on=['UPID', 'Drug Name'], how='left')
        # 16. Fill missing Directories with the calculated ones and track source
        rows_to_fill = df['Directory'].isna() & df['Calculated_Directory'].notna()
        df.loc[rows_to_fill, 'Directory_Source'] = 'CALCULATED_MOST_FREQ'
        df['Directory'] = df['Directory'].fillna(df['Calculated_Directory'])
    # 17. Drop temporary columns (Calculated_Directory exists only after a merge)
    df.drop(columns=['Calculated_Directory', 'Primary_Directory'], inplace=True, errors='ignore')

    # 18. Drop rows with missing UPID (empty strings count as missing)
    df['UPID'] = df['UPID'].replace('', np.nan)
    df.dropna(subset=['UPID'], inplace=True)

    # 19. Export rows with NA Directory to CSV for analysis (diagnostics)
    na_directory_rows = df[df['Directory'].isna()]
    if len(na_directory_rows) > 0:
        na_directory_rows.to_csv(paths.na_directory_rows_csv, index=False)
        # 20. FALLBACK MECHANISM 1: most frequent directory already assigned
        #     to the same UPID
        valid_upid_dirs = df[df['Directory'].notna()].groupby('UPID')['Directory'].agg(
            lambda x: x.value_counts().index[0] if len(x.value_counts()) > 0 else None
        ).to_dict()
        inferred = df['UPID'].map(valid_upid_dirs)
        inferable = df['Directory'].isna() & inferred.notna()
        df['Directory'] = df['Directory'].fillna(inferred)
        df.loc[inferable, 'Directory_Source'] = 'UPID_INFERENCE'

    # 21. FALLBACK MECHANISM 2: label remaining NA as "Undefined"
    rows_undefined = df['Directory'].isna()
    df.loc[rows_undefined, 'Directory_Source'] = 'UNDEFINED'
    df['Directory'] = df['Directory'].fillna("Undefined")

    # 22. Return the processed DataFrame
    return df
def ta_list_get(paths: Optional[PathConfig] = None):
    """
    Download the NICE technology appraisal spreadsheet and list usable TAs.

    The workbook is saved to paths.ta_recommendations_xlsx and read back.
    Only rows categorised as "Recommended" or "Optimised" whose technology
    type is "Pharmaceutical" are kept. Each "TA ID" is reduced to its digits
    and rewritten as "NICE TA<number>".

    Args:
        paths: Path configuration (defaults to default_paths).

    Returns:
        DataFrame of de-duplicated ("TA ID", "Indication") pairs.
    """
    resolved_paths = default_paths if paths is None else paths
    source_url = "https://www.nice.org.uk/Media/Default/About/what-we-do/NICE-guidance/NICE-technology-appraisals/TA%20recommendations.xlsx"
    urllib.request.urlretrieve(source_url, resolved_paths.ta_recommendations_xlsx)
    appraisals = pd.read_excel(resolved_paths.ta_recommendations_xlsx, index_col=0)
    # Filter out TA's which are not Recommended or not Pharmaceutical
    is_recommended = appraisals["Categorisation (for specific recommendation)"].isin(["Recommended", "Optimised"])
    appraisals = appraisals[is_recommended]
    is_pharmaceutical = appraisals["Technology type"] == "Pharmaceutical"
    appraisals = appraisals[is_pharmaceutical]
    # Amend TA001 strings to only the integer (int conversion drops leading zeros)
    ta_numbers = appraisals["TA ID"].str.replace(r'\D+', '', regex=True).astype(int)
    appraisals["TA ID"] = ta_numbers
    appraisals["TA ID"] = "NICE TA" + appraisals["TA ID"].astype(str)
    return appraisals[["TA ID", "Indication"]].drop_duplicates()