refactor: reorganize repository to src/ layout
Move 6 packages (core, config, data_processing, analysis, visualization, cli) into src/ to reduce root clutter. Merge tools/data.py into data_processing/transforms.py. Move docs to docs/. Path resolution via .pth file (setup_dev.py), pytest pythonpath config, and sys.path bootstrap in rxconfig.py and CLI entry points. Clean up pyproject.toml deps (remove stale pins, add snowflake-connector-python). Fix tomllib import for Python 3.10 compatibility. All 113 tests pass.
This commit is contained in:
@@ -0,0 +1,42 @@
|
||||
# data_processing Package
|
||||
|
||||
Data layer for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
## Core Responsibilities
|
||||
|
||||
**Data Pipeline:** `Snowflake → Transforms → Pathway Generation → SQLite`
|
||||
|
||||
## Key Modules
|
||||
|
||||
**transforms.py** — Core data transformations (moved from tools/data.py):
|
||||
- `patient_id()` — Creates UPID = Provider Code (first 3 chars) + PersonKey
|
||||
- `drug_names()` — Standardizes drug names via drugnames.csv lookup
|
||||
- `department_identification()` — 5-level fallback chain for directory assignment
|
||||
|
||||
**pathway_pipeline.py** — Pipeline orchestration:
|
||||
- Processes 6 date filter combinations × 2 chart types (directory + indication)
|
||||
- `fetch_and_transform_data()` — Snowflake fetch + UPID/drug/directory transforms
|
||||
- `process_pathway_for_date_filter()` — Directory charts using `generate_icicle_chart()`
|
||||
- `process_indication_pathway_for_date_filter()` — Indication charts using `generate_icicle_chart_indication()`
|
||||
- `insert_pathway_records()` — SQLite insertion with parameterized queries
|
||||
|
||||
**diagnosis_lookup.py** — GP diagnosis matching:
|
||||
- `get_patient_indication_groups()` — Batch queries Snowflake (500 patients at a time)
|
||||
- Embeds ~148 Search_Term → Cluster_ID mappings as SQL CTE
|
||||
- Returns most recent match per patient via `QUALIFY ROW_NUMBER()`
|
||||
|
||||
**database.py** — SQLite connection pooling and transaction management
|
||||
|
||||
**schema.py** — SQL schema definitions (reference tables + pathway_nodes)
|
||||
|
||||
**snowflake_connector.py** — Snowflake SSO integration via externalbrowser authenticator
|
||||
|
||||
**cache.py** — Query result caching with TTL-based invalidation
|
||||
|
||||
## Import Pattern
|
||||
|
||||
All imports use package names directly:
|
||||
```python
|
||||
from data_processing.transforms import patient_id, drug_names, department_identification
|
||||
from data_processing.pathway_pipeline import process_all_date_filters
|
||||
```
|
||||
@@ -0,0 +1,208 @@
|
||||
"""
|
||||
Data processing module for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
Contains SQLite database management, data loaders, and Snowflake integration.
|
||||
Handles the migration from CSV-based storage to SQLite for improved performance.
|
||||
|
||||
Submodules:
|
||||
database: SQLite connection management and schema definitions
|
||||
loader: Data loading abstractions (CSV, SQLite, Snowflake)
|
||||
snowflake_connector: Snowflake integration with SSO authentication
|
||||
"""
|
||||
|
||||
from data_processing.database import (
|
||||
DatabaseConfig,
|
||||
DatabaseManager,
|
||||
default_db_config,
|
||||
default_db_manager,
|
||||
)
|
||||
from data_processing.schema import (
|
||||
# Reference table schemas
|
||||
REF_DRUG_NAMES_SCHEMA,
|
||||
REF_ORGANIZATIONS_SCHEMA,
|
||||
REF_DIRECTORIES_SCHEMA,
|
||||
REF_DRUG_DIRECTORY_MAP_SCHEMA,
|
||||
REF_DRUG_INDICATION_CLUSTERS_SCHEMA,
|
||||
REFERENCE_TABLES_SCHEMA,
|
||||
# Combined schema
|
||||
ALL_TABLES_SCHEMA,
|
||||
# Reference table functions
|
||||
create_reference_tables,
|
||||
drop_reference_tables,
|
||||
get_reference_table_counts,
|
||||
verify_reference_tables_exist,
|
||||
# Combined functions
|
||||
create_all_tables,
|
||||
drop_all_tables,
|
||||
get_all_table_counts,
|
||||
verify_all_tables_exist,
|
||||
)
|
||||
|
||||
# Reference data migration functions
|
||||
from data_processing.reference_data import (
|
||||
MigrationResult,
|
||||
migrate_drug_names,
|
||||
get_drug_name_counts,
|
||||
verify_drug_names_migration,
|
||||
migrate_organizations,
|
||||
get_organization_counts,
|
||||
verify_organizations_migration,
|
||||
migrate_directories,
|
||||
get_directory_counts,
|
||||
verify_directories_migration,
|
||||
migrate_drug_directory_map,
|
||||
get_drug_directory_map_counts,
|
||||
verify_drug_directory_map_migration,
|
||||
migrate_drug_indication_clusters,
|
||||
get_drug_indication_cluster_counts,
|
||||
verify_drug_indication_clusters_migration,
|
||||
)
|
||||
|
||||
# Data loader abstractions
|
||||
from data_processing.loader import (
|
||||
DataLoader,
|
||||
FileDataLoader,
|
||||
LoadResult,
|
||||
get_loader,
|
||||
REQUIRED_COLUMNS,
|
||||
OPTIONAL_COLUMNS,
|
||||
)
|
||||
|
||||
# Snowflake connector
|
||||
from data_processing.snowflake_connector import (
|
||||
SnowflakeConnector,
|
||||
SnowflakeConnectionError,
|
||||
SnowflakeNotConfiguredError,
|
||||
SnowflakeNotAvailableError,
|
||||
ConnectionInfo,
|
||||
get_connector,
|
||||
reset_connector,
|
||||
is_snowflake_available,
|
||||
is_snowflake_configured,
|
||||
SNOWFLAKE_AVAILABLE,
|
||||
)
|
||||
|
||||
# Query result caching
|
||||
from data_processing.cache import (
|
||||
QueryCache,
|
||||
CacheEntry,
|
||||
CacheStats,
|
||||
get_cache,
|
||||
reset_cache,
|
||||
is_cache_enabled,
|
||||
)
|
||||
|
||||
# Data source management with fallback chain
|
||||
from data_processing.data_source import (
|
||||
DataSourceType,
|
||||
DataSourceResult,
|
||||
SourceStatus,
|
||||
DataSourceManager,
|
||||
get_data_source_manager,
|
||||
get_data,
|
||||
reset_data_source_manager,
|
||||
)
|
||||
|
||||
# Diagnosis lookup (GP diagnosis validation)
|
||||
from data_processing.diagnosis_lookup import (
|
||||
ClusterSnomedCodes,
|
||||
IndicationValidationResult,
|
||||
DrugIndicationMatchRate,
|
||||
get_drug_clusters,
|
||||
get_drug_cluster_ids,
|
||||
get_cluster_snomed_codes,
|
||||
patient_has_indication,
|
||||
validate_indication,
|
||||
get_indication_match_rate,
|
||||
batch_validate_indications,
|
||||
get_available_clusters,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Database management
|
||||
"DatabaseConfig",
|
||||
"DatabaseManager",
|
||||
"default_db_config",
|
||||
"default_db_manager",
|
||||
# Reference table schemas
|
||||
"REF_DRUG_NAMES_SCHEMA",
|
||||
"REF_ORGANIZATIONS_SCHEMA",
|
||||
"REF_DIRECTORIES_SCHEMA",
|
||||
"REF_DRUG_DIRECTORY_MAP_SCHEMA",
|
||||
"REF_DRUG_INDICATION_CLUSTERS_SCHEMA",
|
||||
"REFERENCE_TABLES_SCHEMA",
|
||||
# Combined schema
|
||||
"ALL_TABLES_SCHEMA",
|
||||
# Reference table functions
|
||||
"create_reference_tables",
|
||||
"drop_reference_tables",
|
||||
"get_reference_table_counts",
|
||||
"verify_reference_tables_exist",
|
||||
# Combined functions
|
||||
"create_all_tables",
|
||||
"drop_all_tables",
|
||||
"get_all_table_counts",
|
||||
"verify_all_tables_exist",
|
||||
# Reference data migration
|
||||
"MigrationResult",
|
||||
"migrate_drug_names",
|
||||
"get_drug_name_counts",
|
||||
"verify_drug_names_migration",
|
||||
"migrate_organizations",
|
||||
"get_organization_counts",
|
||||
"verify_organizations_migration",
|
||||
"migrate_directories",
|
||||
"get_directory_counts",
|
||||
"verify_directories_migration",
|
||||
"migrate_drug_directory_map",
|
||||
"get_drug_directory_map_counts",
|
||||
"verify_drug_directory_map_migration",
|
||||
"migrate_drug_indication_clusters",
|
||||
"get_drug_indication_cluster_counts",
|
||||
"verify_drug_indication_clusters_migration",
|
||||
# Data loader abstractions
|
||||
"DataLoader",
|
||||
"FileDataLoader",
|
||||
"LoadResult",
|
||||
"get_loader",
|
||||
"REQUIRED_COLUMNS",
|
||||
"OPTIONAL_COLUMNS",
|
||||
# Snowflake connector
|
||||
"SnowflakeConnector",
|
||||
"SnowflakeConnectionError",
|
||||
"SnowflakeNotConfiguredError",
|
||||
"SnowflakeNotAvailableError",
|
||||
"ConnectionInfo",
|
||||
"get_connector",
|
||||
"reset_connector",
|
||||
"is_snowflake_available",
|
||||
"is_snowflake_configured",
|
||||
"SNOWFLAKE_AVAILABLE",
|
||||
# Query result caching
|
||||
"QueryCache",
|
||||
"CacheEntry",
|
||||
"CacheStats",
|
||||
"get_cache",
|
||||
"reset_cache",
|
||||
"is_cache_enabled",
|
||||
# Data source management with fallback chain
|
||||
"DataSourceType",
|
||||
"DataSourceResult",
|
||||
"SourceStatus",
|
||||
"DataSourceManager",
|
||||
"get_data_source_manager",
|
||||
"get_data",
|
||||
"reset_data_source_manager",
|
||||
# Diagnosis lookup
|
||||
"ClusterSnomedCodes",
|
||||
"IndicationValidationResult",
|
||||
"DrugIndicationMatchRate",
|
||||
"get_drug_clusters",
|
||||
"get_drug_cluster_ids",
|
||||
"get_cluster_snomed_codes",
|
||||
"patient_has_indication",
|
||||
"validate_indication",
|
||||
"get_indication_match_rate",
|
||||
"batch_validate_indications",
|
||||
"get_available_clusters",
|
||||
]
|
||||
@@ -0,0 +1,553 @@
|
||||
"""
|
||||
Query result caching module for NHS Patient Pathway Analysis.
|
||||
|
||||
Provides file-based caching for Snowflake query results with TTL-based invalidation.
|
||||
Supports different TTLs for historical data vs data including the current date.
|
||||
|
||||
Cache keys are generated from query hashes. Results are stored as compressed JSON.
|
||||
|
||||
Usage:
|
||||
from data_processing.cache import QueryCache, get_cache
|
||||
|
||||
cache = get_cache()
|
||||
|
||||
# Check for cached result
|
||||
result = cache.get(query, params)
|
||||
if result is None:
|
||||
# Execute query and cache result
|
||||
result = execute_query(query, params)
|
||||
cache.set(query, params, result, includes_current_data=False)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, date
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
import gzip
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
from config import get_snowflake_config, CacheConfig
|
||||
from core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntry:
|
||||
"""Metadata for a cached query result."""
|
||||
cache_key: str
|
||||
query_hash: str
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
includes_current_data: bool
|
||||
row_count: int
|
||||
file_size_bytes: int
|
||||
file_path: Path
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats:
|
||||
"""Statistics about the cache."""
|
||||
enabled: bool
|
||||
cache_dir: Path
|
||||
total_entries: int
|
||||
total_size_mb: float
|
||||
max_size_mb: int
|
||||
oldest_entry: Optional[datetime]
|
||||
newest_entry: Optional[datetime]
|
||||
hit_count: int
|
||||
miss_count: int
|
||||
|
||||
|
||||
class QueryCache:
|
||||
"""
|
||||
File-based cache for Snowflake query results.
|
||||
|
||||
Results are stored as gzipped JSON files with TTL-based expiration.
|
||||
Supports different TTLs for historical vs current data.
|
||||
|
||||
Attributes:
|
||||
config: CacheConfig with cache settings
|
||||
cache_dir: Path to cache directory
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[CacheConfig] = None, base_path: Optional[Path] = None):
|
||||
"""
|
||||
Initialize the query cache.
|
||||
|
||||
Args:
|
||||
config: Optional CacheConfig. If not provided, loads from snowflake.toml
|
||||
base_path: Base path for relative cache directory. Defaults to cwd.
|
||||
"""
|
||||
if config is None:
|
||||
sf_config = get_snowflake_config()
|
||||
config = sf_config.cache
|
||||
|
||||
self._config = config
|
||||
self._base_path = base_path or Path.cwd()
|
||||
|
||||
# Resolve cache directory
|
||||
cache_dir = Path(config.directory)
|
||||
if not cache_dir.is_absolute():
|
||||
cache_dir = self._base_path / cache_dir
|
||||
self._cache_dir = cache_dir
|
||||
|
||||
# Stats tracking (in-memory only, reset on restart)
|
||||
self._hit_count = 0
|
||||
self._miss_count = 0
|
||||
|
||||
# Ensure cache directory exists if enabled
|
||||
if self._config.enabled:
|
||||
self._cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@property
|
||||
def config(self) -> CacheConfig:
|
||||
"""Return the cache configuration."""
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def cache_dir(self) -> Path:
|
||||
"""Return the cache directory path."""
|
||||
return self._cache_dir
|
||||
|
||||
@property
|
||||
def is_enabled(self) -> bool:
|
||||
"""Return True if caching is enabled."""
|
||||
return self._config.enabled
|
||||
|
||||
def _generate_cache_key(self, query: str, params: Optional[tuple] = None) -> str:
|
||||
"""
|
||||
Generate a cache key from query and parameters.
|
||||
|
||||
Uses SHA256 hash of query + params to create unique key.
|
||||
"""
|
||||
# Normalize query (strip whitespace, lowercase)
|
||||
normalized_query = " ".join(query.lower().split())
|
||||
|
||||
# Combine query and params
|
||||
key_content = normalized_query
|
||||
if params:
|
||||
key_content += "|" + "|".join(str(p) for p in params)
|
||||
|
||||
# Hash to create key
|
||||
hash_obj = hashlib.sha256(key_content.encode("utf-8"))
|
||||
return hash_obj.hexdigest()[:32] # Use first 32 chars for readability
|
||||
|
||||
def _get_cache_file_path(self, cache_key: str) -> Path:
|
||||
"""Get the file path for a cache entry."""
|
||||
return self._cache_dir / f"{cache_key}.json.gz"
|
||||
|
||||
def _get_meta_file_path(self, cache_key: str) -> Path:
|
||||
"""Get the metadata file path for a cache entry."""
|
||||
return self._cache_dir / f"{cache_key}.meta.json"
|
||||
|
||||
def _is_expired(self, meta: dict) -> bool:
|
||||
"""Check if a cache entry is expired based on its metadata."""
|
||||
expires_at = datetime.fromisoformat(meta["expires_at"])
|
||||
return datetime.now() > expires_at
|
||||
|
||||
def get(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[tuple] = None,
|
||||
check_expiry: bool = True
|
||||
) -> Optional[list[dict]]:
|
||||
"""
|
||||
Get a cached query result.
|
||||
|
||||
Args:
|
||||
query: SQL query string
|
||||
params: Optional query parameters
|
||||
check_expiry: If True, returns None for expired entries
|
||||
|
||||
Returns:
|
||||
Cached result as list of dicts, or None if not cached/expired
|
||||
"""
|
||||
if not self.is_enabled:
|
||||
self._miss_count += 1
|
||||
return None
|
||||
|
||||
cache_key = self._generate_cache_key(query, params)
|
||||
cache_file = self._get_cache_file_path(cache_key)
|
||||
meta_file = self._get_meta_file_path(cache_key)
|
||||
|
||||
# Check if files exist
|
||||
if not cache_file.exists() or not meta_file.exists():
|
||||
self._miss_count += 1
|
||||
logger.debug(f"Cache miss (not found): {cache_key}")
|
||||
return None
|
||||
|
||||
# Load and check metadata
|
||||
try:
|
||||
with open(meta_file, "r", encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
if check_expiry and self._is_expired(meta):
|
||||
self._miss_count += 1
|
||||
logger.debug(f"Cache miss (expired): {cache_key}")
|
||||
return None
|
||||
|
||||
# Load cached data
|
||||
with gzip.open(cache_file, "rt", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
self._hit_count += 1
|
||||
logger.info(f"Cache hit: {cache_key} ({meta['row_count']} rows)")
|
||||
return data
|
||||
|
||||
except (json.JSONDecodeError, KeyError, OSError) as e:
|
||||
logger.warning(f"Cache read error for {cache_key}: {e}")
|
||||
self._miss_count += 1
|
||||
# Clean up corrupted entry
|
||||
self._delete_entry(cache_key)
|
||||
return None
|
||||
|
||||
def set(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[tuple],
|
||||
data: list[dict],
|
||||
includes_current_data: bool = False,
|
||||
custom_ttl_seconds: Optional[int] = None
|
||||
) -> Optional[CacheEntry]:
|
||||
"""
|
||||
Cache a query result.
|
||||
|
||||
Args:
|
||||
query: SQL query string
|
||||
params: Optional query parameters
|
||||
data: Query result as list of dicts
|
||||
includes_current_data: If True, uses shorter TTL for current data
|
||||
custom_ttl_seconds: Optional custom TTL (overrides config)
|
||||
|
||||
Returns:
|
||||
CacheEntry with metadata, or None if caching disabled/failed
|
||||
"""
|
||||
if not self.is_enabled:
|
||||
return None
|
||||
|
||||
cache_key = self._generate_cache_key(query, params)
|
||||
cache_file = self._get_cache_file_path(cache_key)
|
||||
meta_file = self._get_meta_file_path(cache_key)
|
||||
|
||||
# Determine TTL
|
||||
if custom_ttl_seconds is not None:
|
||||
ttl = custom_ttl_seconds
|
||||
elif includes_current_data:
|
||||
ttl = self._config.ttl_current_data_seconds
|
||||
else:
|
||||
ttl = self._config.ttl_seconds
|
||||
|
||||
now = datetime.now()
|
||||
expires_at = datetime.fromtimestamp(now.timestamp() + ttl)
|
||||
|
||||
try:
|
||||
# Write compressed data
|
||||
with gzip.open(cache_file, "wt", encoding="utf-8", compresslevel=6) as f:
|
||||
json.dump(data, f, default=str)
|
||||
|
||||
file_size = cache_file.stat().st_size
|
||||
|
||||
# Write metadata
|
||||
meta = {
|
||||
"cache_key": cache_key,
|
||||
"query_hash": hashlib.sha256(query.encode()).hexdigest()[:16],
|
||||
"created_at": now.isoformat(),
|
||||
"expires_at": expires_at.isoformat(),
|
||||
"includes_current_data": includes_current_data,
|
||||
"row_count": len(data),
|
||||
"file_size_bytes": file_size,
|
||||
"ttl_seconds": ttl,
|
||||
}
|
||||
|
||||
with open(meta_file, "w", encoding="utf-8") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
|
||||
logger.info(f"Cached {len(data)} rows as {cache_key} (expires in {ttl}s)")
|
||||
|
||||
# Check if we need to enforce size limit
|
||||
self._enforce_size_limit()
|
||||
|
||||
return CacheEntry(
|
||||
cache_key=cache_key,
|
||||
query_hash=str(meta["query_hash"]),
|
||||
created_at=now,
|
||||
expires_at=expires_at,
|
||||
includes_current_data=includes_current_data,
|
||||
row_count=len(data),
|
||||
file_size_bytes=file_size,
|
||||
file_path=cache_file,
|
||||
)
|
||||
|
||||
except (OSError, TypeError) as e:
|
||||
logger.error(f"Failed to cache result: {e}")
|
||||
return None
|
||||
|
||||
def invalidate(self, query: str, params: Optional[tuple] = None) -> bool:
|
||||
"""
|
||||
Invalidate a specific cache entry.
|
||||
|
||||
Args:
|
||||
query: SQL query string
|
||||
params: Optional query parameters
|
||||
|
||||
Returns:
|
||||
True if entry was deleted, False if not found
|
||||
"""
|
||||
cache_key = self._generate_cache_key(query, params)
|
||||
return self._delete_entry(cache_key)
|
||||
|
||||
def _delete_entry(self, cache_key: str) -> bool:
|
||||
"""Delete a cache entry by key."""
|
||||
cache_file = self._get_cache_file_path(cache_key)
|
||||
meta_file = self._get_meta_file_path(cache_key)
|
||||
|
||||
deleted = False
|
||||
|
||||
if cache_file.exists():
|
||||
cache_file.unlink()
|
||||
deleted = True
|
||||
|
||||
if meta_file.exists():
|
||||
meta_file.unlink()
|
||||
deleted = True
|
||||
|
||||
if deleted:
|
||||
logger.debug(f"Deleted cache entry: {cache_key}")
|
||||
|
||||
return deleted
|
||||
|
||||
def clear(self) -> int:
|
||||
"""
|
||||
Clear all cache entries.
|
||||
|
||||
Returns:
|
||||
Number of entries deleted
|
||||
"""
|
||||
if not self._cache_dir.exists():
|
||||
return 0
|
||||
|
||||
count = 0
|
||||
for file in self._cache_dir.glob("*.json*"):
|
||||
try:
|
||||
file.unlink()
|
||||
count += 1
|
||||
except OSError as e:
|
||||
logger.warning(f"Failed to delete {file}: {e}")
|
||||
|
||||
# Reset stats
|
||||
self._hit_count = 0
|
||||
self._miss_count = 0
|
||||
|
||||
logger.info(f"Cleared {count} cache files")
|
||||
return count // 2 # Divide by 2 since we have .json.gz and .meta.json
|
||||
|
||||
def clear_expired(self) -> int:
|
||||
"""
|
||||
Remove expired cache entries.
|
||||
|
||||
Returns:
|
||||
Number of expired entries deleted
|
||||
"""
|
||||
if not self._cache_dir.exists():
|
||||
return 0
|
||||
|
||||
count = 0
|
||||
for meta_file in self._cache_dir.glob("*.meta.json"):
|
||||
try:
|
||||
with open(meta_file, "r", encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
if self._is_expired(meta):
|
||||
cache_key = meta_file.stem.replace(".meta", "")
|
||||
self._delete_entry(cache_key)
|
||||
count += 1
|
||||
except (OSError, json.JSONDecodeError):
|
||||
# Delete corrupted metadata files
|
||||
cache_key = meta_file.stem.replace(".meta", "")
|
||||
self._delete_entry(cache_key)
|
||||
count += 1
|
||||
|
||||
logger.info(f"Cleared {count} expired cache entries")
|
||||
return count
|
||||
|
||||
def _get_total_size_mb(self) -> float:
|
||||
"""Calculate total cache size in MB."""
|
||||
if not self._cache_dir.exists():
|
||||
return 0.0
|
||||
|
||||
total_bytes = sum(
|
||||
f.stat().st_size
|
||||
for f in self._cache_dir.glob("*")
|
||||
if f.is_file()
|
||||
)
|
||||
return total_bytes / (1024 * 1024)
|
||||
|
||||
def _enforce_size_limit(self) -> int:
|
||||
"""
|
||||
Enforce cache size limit by removing oldest entries.
|
||||
|
||||
Returns:
|
||||
Number of entries removed
|
||||
"""
|
||||
max_size_mb = self._config.max_size_mb
|
||||
current_size_mb = self._get_total_size_mb()
|
||||
|
||||
if current_size_mb <= max_size_mb:
|
||||
return 0
|
||||
|
||||
# Get all entries sorted by creation time
|
||||
entries = []
|
||||
for meta_file in self._cache_dir.glob("*.meta.json"):
|
||||
try:
|
||||
with open(meta_file, "r", encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
entries.append((
|
||||
meta_file.stem.replace(".meta", ""),
|
||||
datetime.fromisoformat(meta["created_at"]),
|
||||
meta.get("file_size_bytes", 0)
|
||||
))
|
||||
except (OSError, json.JSONDecodeError, KeyError):
|
||||
# Clean up corrupted entry
|
||||
cache_key = meta_file.stem.replace(".meta", "")
|
||||
self._delete_entry(cache_key)
|
||||
|
||||
# Sort by creation time (oldest first)
|
||||
entries.sort(key=lambda x: x[1])
|
||||
|
||||
# Remove oldest entries until under limit
|
||||
removed = 0
|
||||
size_to_remove_bytes = (current_size_mb - max_size_mb * 0.9) * 1024 * 1024 # Target 90% of limit
|
||||
removed_bytes = 0
|
||||
|
||||
for cache_key, created_at, file_size in entries:
|
||||
if removed_bytes >= size_to_remove_bytes:
|
||||
break
|
||||
|
||||
self._delete_entry(cache_key)
|
||||
removed_bytes += file_size
|
||||
removed += 1
|
||||
|
||||
logger.info(f"Removed {removed} cache entries to enforce size limit")
|
||||
return removed
|
||||
|
||||
def get_stats(self) -> CacheStats:
|
||||
"""Get cache statistics."""
|
||||
if not self._cache_dir.exists():
|
||||
return CacheStats(
|
||||
enabled=self.is_enabled,
|
||||
cache_dir=self._cache_dir,
|
||||
total_entries=0,
|
||||
total_size_mb=0.0,
|
||||
max_size_mb=self._config.max_size_mb,
|
||||
oldest_entry=None,
|
||||
newest_entry=None,
|
||||
hit_count=self._hit_count,
|
||||
miss_count=self._miss_count,
|
||||
)
|
||||
|
||||
entries = []
|
||||
for meta_file in self._cache_dir.glob("*.meta.json"):
|
||||
try:
|
||||
with open(meta_file, "r", encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
entries.append(datetime.fromisoformat(meta["created_at"]))
|
||||
except (OSError, json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
|
||||
oldest = min(entries) if entries else None
|
||||
newest = max(entries) if entries else None
|
||||
|
||||
return CacheStats(
|
||||
enabled=self.is_enabled,
|
||||
cache_dir=self._cache_dir,
|
||||
total_entries=len(entries),
|
||||
total_size_mb=self._get_total_size_mb(),
|
||||
max_size_mb=self._config.max_size_mb,
|
||||
oldest_entry=oldest,
|
||||
newest_entry=newest,
|
||||
hit_count=self._hit_count,
|
||||
miss_count=self._miss_count,
|
||||
)
|
||||
|
||||
def list_entries(self) -> list[CacheEntry]:
|
||||
"""List all cache entries with metadata."""
|
||||
if not self._cache_dir.exists():
|
||||
return []
|
||||
|
||||
entries = []
|
||||
for meta_file in self._cache_dir.glob("*.meta.json"):
|
||||
try:
|
||||
with open(meta_file, "r", encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
cache_key = meta["cache_key"]
|
||||
entries.append(CacheEntry(
|
||||
cache_key=cache_key,
|
||||
query_hash=meta.get("query_hash", ""),
|
||||
created_at=datetime.fromisoformat(meta["created_at"]),
|
||||
expires_at=datetime.fromisoformat(meta["expires_at"]),
|
||||
includes_current_data=meta.get("includes_current_data", False),
|
||||
row_count=meta.get("row_count", 0),
|
||||
file_size_bytes=meta.get("file_size_bytes", 0),
|
||||
file_path=self._get_cache_file_path(cache_key),
|
||||
))
|
||||
except (OSError, json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
|
||||
# Sort by creation time (newest first)
|
||||
entries.sort(key=lambda x: x.created_at, reverse=True)
|
||||
return entries
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
_default_cache: Optional[QueryCache] = None
|
||||
|
||||
|
||||
def get_cache(config: Optional[CacheConfig] = None) -> QueryCache:
|
||||
"""
|
||||
Get a QueryCache instance (creates singleton on first call).
|
||||
|
||||
Args:
|
||||
config: Optional CacheConfig. If provided, creates new cache with
|
||||
this config. If None, uses/creates default cache.
|
||||
|
||||
Returns:
|
||||
QueryCache instance
|
||||
"""
|
||||
global _default_cache
|
||||
|
||||
if config is not None:
|
||||
# Custom config requested, create new cache
|
||||
return QueryCache(config)
|
||||
|
||||
if _default_cache is None:
|
||||
_default_cache = QueryCache()
|
||||
|
||||
return _default_cache
|
||||
|
||||
|
||||
def reset_cache() -> None:
|
||||
"""Reset the default cache singleton."""
|
||||
global _default_cache
|
||||
_default_cache = None
|
||||
|
||||
|
||||
def is_cache_enabled() -> bool:
|
||||
"""Return True if caching is enabled in configuration."""
|
||||
config = get_snowflake_config()
|
||||
return config.cache.enabled
|
||||
|
||||
|
||||
# Export public API
|
||||
__all__ = [
|
||||
"QueryCache",
|
||||
"CacheEntry",
|
||||
"CacheStats",
|
||||
"get_cache",
|
||||
"reset_cache",
|
||||
"is_cache_enabled",
|
||||
]
|
||||
@@ -0,0 +1,932 @@
|
||||
"""
|
||||
Unified data access layer with fallback chain for NHS Patient Pathway Analysis.
|
||||
|
||||
Provides a high-level interface that automatically selects the best available data source:
|
||||
1. Cache - Returns cached results if valid and not expired
|
||||
2. Snowflake - Queries Snowflake warehouse if configured and connected
|
||||
3. Local - Falls back to SQLite database or CSV/Parquet files
|
||||
|
||||
The fallback chain handles connection errors, missing configurations, and
|
||||
unavailable services gracefully, always attempting to provide data from
|
||||
some source.
|
||||
|
||||
Usage:
|
||||
from data_processing.data_source import DataSourceManager, get_data
|
||||
|
||||
# Simple usage with automatic source selection
|
||||
result = get_data(
|
||||
start_date=date(2024, 1, 1),
|
||||
end_date=date(2024, 12, 31),
|
||||
trusts=["TRUST A", "TRUST B"],
|
||||
)
|
||||
|
||||
# Or with explicit source preference
|
||||
manager = DataSourceManager()
|
||||
result = manager.get_data(
|
||||
start_date=date(2024, 1, 1),
|
||||
end_date=date(2024, 12, 31),
|
||||
preferred_source="snowflake",
|
||||
)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import date, datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional, Callable
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DataSourceType(Enum):
|
||||
"""Enumeration of available data sources."""
|
||||
CACHE = "cache"
|
||||
SNOWFLAKE = "snowflake"
|
||||
SQLITE = "sqlite"
|
||||
FILE = "file"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataSourceResult:
|
||||
"""Result from data source query.
|
||||
|
||||
Attributes:
|
||||
df: The loaded DataFrame with patient intervention data
|
||||
source_type: Which data source was used
|
||||
source_detail: Additional details about the source (e.g., file path, query hash)
|
||||
row_count: Number of rows returned
|
||||
cached: Whether the result came from cache
|
||||
from_fallback: Whether a fallback source was used
|
||||
load_time_seconds: Time taken to load data
|
||||
warnings: Any warnings generated during loading
|
||||
"""
|
||||
df: pd.DataFrame
|
||||
source_type: DataSourceType
|
||||
source_detail: str = ""
|
||||
row_count: int = 0
|
||||
cached: bool = False
|
||||
from_fallback: bool = False
|
||||
load_time_seconds: float = 0.0
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.row_count == 0 and self.df is not None:
|
||||
self.row_count = len(self.df)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SourceStatus:
|
||||
"""Status of a data source.
|
||||
|
||||
Attributes:
|
||||
source_type: The type of data source
|
||||
available: Whether the source is available
|
||||
configured: Whether the source is properly configured
|
||||
message: Status message explaining the state
|
||||
last_checked: When the status was last checked
|
||||
"""
|
||||
source_type: DataSourceType
|
||||
available: bool = False
|
||||
configured: bool = False
|
||||
message: str = ""
|
||||
last_checked: Optional[datetime] = None
|
||||
|
||||
|
||||
class DataSourceManager:
|
||||
"""
|
||||
Manages data access with automatic fallback between sources.
|
||||
|
||||
The manager attempts to retrieve data from sources in order of preference:
|
||||
1. Cache (if enabled and has valid cached data)
|
||||
2. Snowflake (if configured and connected)
|
||||
3. SQLite (if database exists with data)
|
||||
4. Local files (CSV/Parquet)
|
||||
|
||||
Attributes:
|
||||
cache_enabled: Whether to use caching
|
||||
local_file_path: Path to local CSV/Parquet file (optional fallback)
|
||||
sqlite_db_path: Path to SQLite database (optional)
|
||||
|
||||
Example:
|
||||
manager = DataSourceManager()
|
||||
|
||||
# Check what sources are available
|
||||
status = manager.check_all_sources()
|
||||
for s in status:
|
||||
print(f"{s.source_type.value}: {s.message}")
|
||||
|
||||
# Get data with automatic fallback
|
||||
result = manager.get_data(
|
||||
start_date=date(2024, 1, 1),
|
||||
end_date=date(2024, 6, 30),
|
||||
)
|
||||
print(f"Got {result.row_count} rows from {result.source_type.value}")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_enabled: bool = True,
|
||||
local_file_path: Optional[Path | str] = None,
|
||||
sqlite_db_path: Optional[Path | str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the data source manager.
|
||||
|
||||
Args:
|
||||
cache_enabled: Whether to check cache before querying (default True)
|
||||
local_file_path: Path to local CSV/Parquet file for file fallback
|
||||
sqlite_db_path: Path to SQLite database (uses default if None)
|
||||
"""
|
||||
self._cache_enabled = cache_enabled
|
||||
self._local_file_path = Path(local_file_path) if local_file_path else None
|
||||
self._sqlite_db_path = Path(sqlite_db_path) if sqlite_db_path else None
|
||||
self._source_status: dict[DataSourceType, SourceStatus] = {}
|
||||
|
||||
@property
|
||||
def cache_enabled(self) -> bool:
|
||||
"""Return whether caching is enabled."""
|
||||
return self._cache_enabled
|
||||
|
||||
@cache_enabled.setter
|
||||
def cache_enabled(self, value: bool):
|
||||
"""Set whether caching is enabled."""
|
||||
self._cache_enabled = value
|
||||
|
||||
def _check_cache_status(self) -> SourceStatus:
|
||||
"""Check if cache is available."""
|
||||
try:
|
||||
from data_processing.cache import is_cache_enabled, get_cache
|
||||
|
||||
if not is_cache_enabled():
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.CACHE,
|
||||
available=False,
|
||||
configured=False,
|
||||
message="Cache disabled in configuration",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
cache = get_cache()
|
||||
stats = cache.get_stats()
|
||||
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.CACHE,
|
||||
available=True,
|
||||
configured=True,
|
||||
message=f"Cache enabled ({stats.total_entries} entries, {stats.total_size_mb:.1f}MB)",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
except Exception as e:
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.CACHE,
|
||||
available=False,
|
||||
configured=False,
|
||||
message=f"Cache error: {e}",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
def _check_snowflake_status(self) -> SourceStatus:
|
||||
"""Check if Snowflake is available and configured."""
|
||||
try:
|
||||
from data_processing.snowflake_connector import (
|
||||
is_snowflake_available,
|
||||
is_snowflake_configured,
|
||||
)
|
||||
|
||||
if not is_snowflake_available():
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.SNOWFLAKE,
|
||||
available=False,
|
||||
configured=False,
|
||||
message="snowflake-connector-python not installed",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
if not is_snowflake_configured():
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.SNOWFLAKE,
|
||||
available=True,
|
||||
configured=False,
|
||||
message="Snowflake account not configured in config/snowflake.toml",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.SNOWFLAKE,
|
||||
available=True,
|
||||
configured=True,
|
||||
message="Snowflake configured and ready",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
except Exception as e:
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.SNOWFLAKE,
|
||||
available=False,
|
||||
configured=False,
|
||||
message=f"Snowflake error: {e}",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
def _check_sqlite_status(self) -> SourceStatus:
|
||||
"""Check if SQLite database is available with pathway data."""
|
||||
try:
|
||||
from data_processing.database import default_db_config
|
||||
|
||||
db_path = self._sqlite_db_path or Path(default_db_config.db_path)
|
||||
|
||||
if not db_path.exists():
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.SQLITE,
|
||||
available=False,
|
||||
configured=True,
|
||||
message=f"Database not found: {db_path}",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
from data_processing.database import DatabaseManager, DatabaseConfig
|
||||
|
||||
config = DatabaseConfig(db_path=db_path)
|
||||
manager = DatabaseManager(config)
|
||||
|
||||
if not manager.table_exists("pathway_nodes"):
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.SQLITE,
|
||||
available=False,
|
||||
configured=True,
|
||||
message="pathway_nodes table not found",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
count = manager.get_table_count("pathway_nodes")
|
||||
if count == 0:
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.SQLITE,
|
||||
available=False,
|
||||
configured=True,
|
||||
message="pathway_nodes table is empty",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.SQLITE,
|
||||
available=True,
|
||||
configured=True,
|
||||
message=f"SQLite database ready ({count:,} pathway nodes)",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
except Exception as e:
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.SQLITE,
|
||||
available=False,
|
||||
configured=False,
|
||||
message=f"SQLite error: {e}",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
def _check_file_status(self) -> SourceStatus:
|
||||
"""Check if local file is available."""
|
||||
if self._local_file_path is None:
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.FILE,
|
||||
available=False,
|
||||
configured=False,
|
||||
message="No local file path configured",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
if not self._local_file_path.exists():
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.FILE,
|
||||
available=False,
|
||||
configured=True,
|
||||
message=f"File not found: {self._local_file_path}",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
size_mb = self._local_file_path.stat().st_size / (1024 * 1024)
|
||||
return SourceStatus(
|
||||
source_type=DataSourceType.FILE,
|
||||
available=True,
|
||||
configured=True,
|
||||
message=f"Local file ready: {self._local_file_path.name} ({size_mb:.1f}MB)",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
def check_source_status(self, source_type: DataSourceType) -> SourceStatus:
|
||||
"""
|
||||
Check the status of a specific data source.
|
||||
|
||||
Args:
|
||||
source_type: The type of source to check
|
||||
|
||||
Returns:
|
||||
SourceStatus with current availability information
|
||||
"""
|
||||
if source_type == DataSourceType.CACHE:
|
||||
return self._check_cache_status()
|
||||
elif source_type == DataSourceType.SNOWFLAKE:
|
||||
return self._check_snowflake_status()
|
||||
elif source_type == DataSourceType.SQLITE:
|
||||
return self._check_sqlite_status()
|
||||
elif source_type == DataSourceType.FILE:
|
||||
return self._check_file_status()
|
||||
else:
|
||||
return SourceStatus(
|
||||
source_type=source_type,
|
||||
available=False,
|
||||
configured=False,
|
||||
message=f"Unknown source type: {source_type}",
|
||||
last_checked=datetime.now(),
|
||||
)
|
||||
|
||||
def check_all_sources(self) -> list[SourceStatus]:
|
||||
"""
|
||||
Check the status of all data sources.
|
||||
|
||||
Returns:
|
||||
List of SourceStatus for each source type
|
||||
"""
|
||||
statuses = []
|
||||
for source_type in DataSourceType:
|
||||
status = self.check_source_status(source_type)
|
||||
self._source_status[source_type] = status
|
||||
statuses.append(status)
|
||||
return statuses
|
||||
|
||||
def _build_cache_key_params(
|
||||
self,
|
||||
start_date: Optional[date],
|
||||
end_date: Optional[date],
|
||||
trusts: Optional[list[str]],
|
||||
drugs: Optional[list[str]],
|
||||
directories: Optional[list[str]],
|
||||
) -> tuple[str, tuple]:
|
||||
"""Build a cache-compatible query string and params for the filter criteria."""
|
||||
# Create a canonical representation for caching
|
||||
query_parts = ["SELECT * FROM activity_data"]
|
||||
params = []
|
||||
|
||||
conditions = []
|
||||
if start_date:
|
||||
conditions.append("start_date >= ?")
|
||||
params.append(str(start_date))
|
||||
if end_date:
|
||||
conditions.append("end_date <= ?")
|
||||
params.append(str(end_date))
|
||||
if trusts:
|
||||
placeholders = ",".join(["?"] * len(trusts))
|
||||
conditions.append(f"trust IN ({placeholders})")
|
||||
params.extend(sorted(trusts))
|
||||
if drugs:
|
||||
placeholders = ",".join(["?"] * len(drugs))
|
||||
conditions.append(f"drug IN ({placeholders})")
|
||||
params.extend(sorted(drugs))
|
||||
if directories:
|
||||
placeholders = ",".join(["?"] * len(directories))
|
||||
conditions.append(f"directory IN ({placeholders})")
|
||||
params.extend(sorted(directories))
|
||||
|
||||
if conditions:
|
||||
query_parts.append("WHERE " + " AND ".join(conditions))
|
||||
|
||||
query = " ".join(query_parts)
|
||||
return query, tuple(params)
|
||||
|
||||
def _try_cache(
|
||||
self,
|
||||
start_date: Optional[date],
|
||||
end_date: Optional[date],
|
||||
trusts: Optional[list[str]],
|
||||
drugs: Optional[list[str]],
|
||||
directories: Optional[list[str]],
|
||||
) -> Optional[DataSourceResult]:
|
||||
"""Try to get data from cache."""
|
||||
if not self._cache_enabled:
|
||||
return None
|
||||
|
||||
try:
|
||||
from data_processing.cache import get_cache
|
||||
|
||||
cache = get_cache()
|
||||
if not cache.is_enabled:
|
||||
return None
|
||||
|
||||
query, params = self._build_cache_key_params(
|
||||
start_date, end_date, trusts, drugs, directories
|
||||
)
|
||||
|
||||
cached_data = cache.get(query, params)
|
||||
if cached_data is None:
|
||||
logger.debug("Cache miss")
|
||||
return None
|
||||
|
||||
# Convert cached data back to DataFrame
|
||||
df = pd.DataFrame(cached_data)
|
||||
|
||||
# Convert date columns
|
||||
if 'Intervention Date' in df.columns:
|
||||
df['Intervention Date'] = pd.to_datetime(df['Intervention Date'])
|
||||
|
||||
logger.info(f"Cache hit: {len(df)} rows")
|
||||
|
||||
return DataSourceResult(
|
||||
df=df,
|
||||
source_type=DataSourceType.CACHE,
|
||||
source_detail=f"cache_key={query[:50]}...",
|
||||
row_count=len(df),
|
||||
cached=True,
|
||||
from_fallback=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache lookup failed: {e}")
|
||||
return None
|
||||
|
||||
def _try_snowflake(
|
||||
self,
|
||||
start_date: Optional[date],
|
||||
end_date: Optional[date],
|
||||
trusts: Optional[list[str]],
|
||||
drugs: Optional[list[str]],
|
||||
directories: Optional[list[str]],
|
||||
progress_callback: Optional[Callable[[int, int], None]] = None,
|
||||
) -> Optional[DataSourceResult]:
|
||||
"""Try to get data from Snowflake."""
|
||||
import time
|
||||
|
||||
try:
|
||||
from data_processing.snowflake_connector import (
|
||||
is_snowflake_available,
|
||||
is_snowflake_configured,
|
||||
get_connector,
|
||||
SnowflakeConnectionError,
|
||||
)
|
||||
|
||||
if not is_snowflake_available():
|
||||
logger.debug("Snowflake connector not installed")
|
||||
return None
|
||||
|
||||
if not is_snowflake_configured():
|
||||
logger.debug("Snowflake not configured")
|
||||
return None
|
||||
|
||||
# Get connector and fetch data
|
||||
connector = get_connector()
|
||||
logger.info("Fetching data from Snowflake...")
|
||||
start_time = time.time()
|
||||
|
||||
# Fetch activity data from Snowflake
|
||||
# Note: provider_codes filter not directly supported yet - would need trust name to code mapping
|
||||
rows = connector.fetch_activity_data(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
provider_codes=None, # TODO: map trust names to provider codes if needed
|
||||
)
|
||||
|
||||
if not rows:
|
||||
logger.warning("Snowflake returned no data")
|
||||
return None
|
||||
|
||||
# Convert to DataFrame
|
||||
df = pd.DataFrame(rows)
|
||||
load_time = time.time() - start_time
|
||||
|
||||
logger.info(f"Snowflake loaded {len(df)} rows in {load_time:.2f}s")
|
||||
|
||||
# Apply local transformations to match expected format
|
||||
# (patient_id, drug_names, department_identification)
|
||||
from data_processing.transforms import patient_id, drug_names, department_identification
|
||||
from core import default_paths
|
||||
|
||||
df = patient_id(df)
|
||||
df = drug_names(df, paths=default_paths)
|
||||
df = department_identification(df, paths=default_paths)
|
||||
|
||||
# Apply additional filters if provided
|
||||
if trusts and 'OrganisationName' in df.columns:
|
||||
df = df[df['OrganisationName'].isin(trusts)]
|
||||
if drugs and 'Drug Name' in df.columns:
|
||||
df = df[df['Drug Name'].isin(drugs)]
|
||||
if directories and 'Directory' in df.columns:
|
||||
df = df[df['Directory'].isin(directories)]
|
||||
|
||||
return DataSourceResult(
|
||||
df=df,
|
||||
source_type=DataSourceType.SNOWFLAKE,
|
||||
source_detail="DATA_HUB.CDM.Acute__Conmon__PatientLevelDrugs",
|
||||
row_count=len(df),
|
||||
cached=False,
|
||||
from_fallback=False,
|
||||
load_time_seconds=load_time,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Snowflake query failed: {e}")
|
||||
return None
|
||||
|
||||
def _try_sqlite(
|
||||
self,
|
||||
start_date: Optional[date],
|
||||
end_date: Optional[date],
|
||||
trusts: Optional[list[str]],
|
||||
drugs: Optional[list[str]],
|
||||
directories: Optional[list[str]],
|
||||
) -> Optional[DataSourceResult]:
|
||||
"""Try to get data from SQLite.
|
||||
|
||||
Note: Raw intervention data is no longer stored in SQLite.
|
||||
The app now uses pre-computed pathway_nodes via load_pathway_data().
|
||||
This fallback is retained for interface compatibility but always returns None.
|
||||
"""
|
||||
logger.debug("SQLite raw data fallback skipped (fact_interventions removed)")
|
||||
return None
|
||||
|
||||
def _try_file(
|
||||
self,
|
||||
start_date: Optional[date],
|
||||
end_date: Optional[date],
|
||||
trusts: Optional[list[str]],
|
||||
drugs: Optional[list[str]],
|
||||
directories: Optional[list[str]],
|
||||
) -> Optional[DataSourceResult]:
|
||||
"""Try to get data from local file."""
|
||||
import time
|
||||
|
||||
if self._local_file_path is None:
|
||||
logger.debug("No local file configured")
|
||||
return None
|
||||
|
||||
try:
|
||||
from data_processing.loader import FileDataLoader
|
||||
|
||||
loader = FileDataLoader(file_path=self._local_file_path)
|
||||
|
||||
is_valid, msg = loader.validate_source()
|
||||
if not is_valid:
|
||||
logger.debug(f"Local file not available: {msg}")
|
||||
return None
|
||||
|
||||
start_time = time.time()
|
||||
result = loader.load()
|
||||
df = result.df
|
||||
|
||||
# Apply filters (file loader loads all data, then we filter)
|
||||
if start_date and 'Intervention Date' in df.columns:
|
||||
df = df[df['Intervention Date'] >= pd.Timestamp(start_date)]
|
||||
if end_date and 'Intervention Date' in df.columns:
|
||||
df = df[df['Intervention Date'] < pd.Timestamp(end_date)]
|
||||
if trusts and 'OrganisationName' in df.columns:
|
||||
df = df[df['OrganisationName'].isin(trusts)]
|
||||
if drugs and 'Drug Name' in df.columns:
|
||||
df = df[df['Drug Name'].isin(drugs)]
|
||||
if directories and 'Directory' in df.columns:
|
||||
df = df[df['Directory'].isin(directories)]
|
||||
|
||||
load_time = time.time() - start_time
|
||||
|
||||
logger.info(f"File loaded and filtered: {len(df)} rows in {load_time:.2f}s")
|
||||
|
||||
return DataSourceResult(
|
||||
df=df,
|
||||
source_type=DataSourceType.FILE,
|
||||
source_detail=str(self._local_file_path),
|
||||
row_count=len(df),
|
||||
cached=False,
|
||||
from_fallback=True,
|
||||
load_time_seconds=load_time,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"File load failed: {e}")
|
||||
return None
|
||||
|
||||
def get_data(
|
||||
self,
|
||||
start_date: Optional[date] = None,
|
||||
end_date: Optional[date] = None,
|
||||
trusts: Optional[list[str]] = None,
|
||||
drugs: Optional[list[str]] = None,
|
||||
directories: Optional[list[str]] = None,
|
||||
preferred_source: Optional[str] = None,
|
||||
skip_cache: bool = False,
|
||||
progress_callback: Optional[Callable[[int, int], None]] = None,
|
||||
) -> DataSourceResult:
|
||||
"""
|
||||
Get patient intervention data from the best available source.
|
||||
|
||||
The fallback chain is: Cache → Snowflake → SQLite → File
|
||||
|
||||
Args:
|
||||
start_date: Optional start date for filtering (inclusive)
|
||||
end_date: Optional end date for filtering (exclusive)
|
||||
trusts: Optional list of trust names to filter
|
||||
drugs: Optional list of drug names to filter
|
||||
directories: Optional list of directories to filter
|
||||
preferred_source: Optional preferred source ("snowflake", "sqlite", "file")
|
||||
skip_cache: If True, bypass cache and query source directly
|
||||
progress_callback: Optional callback(current, total) for progress updates
|
||||
|
||||
Returns:
|
||||
DataSourceResult with the loaded data and metadata
|
||||
|
||||
Raises:
|
||||
ValueError: If no data source is available or all sources fail
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
warnings = []
|
||||
|
||||
# If preferred source specified, try that first
|
||||
if preferred_source:
|
||||
preferred = preferred_source.lower()
|
||||
if preferred == "snowflake":
|
||||
result = self._try_snowflake(
|
||||
start_date, end_date, trusts, drugs, directories, progress_callback
|
||||
)
|
||||
if result:
|
||||
result.load_time_seconds = time.time() - start_time
|
||||
return result
|
||||
warnings.append("Preferred source 'snowflake' unavailable")
|
||||
|
||||
elif preferred == "sqlite":
|
||||
result = self._try_sqlite(
|
||||
start_date, end_date, trusts, drugs, directories
|
||||
)
|
||||
if result:
|
||||
result.load_time_seconds = time.time() - start_time
|
||||
return result
|
||||
warnings.append("Preferred source 'sqlite' unavailable")
|
||||
|
||||
elif preferred == "file":
|
||||
result = self._try_file(
|
||||
start_date, end_date, trusts, drugs, directories
|
||||
)
|
||||
if result:
|
||||
result.load_time_seconds = time.time() - start_time
|
||||
return result
|
||||
warnings.append("Preferred source 'file' unavailable")
|
||||
|
||||
# Standard fallback chain: cache → snowflake → sqlite → file
|
||||
|
||||
# 1. Try cache first (unless skipped)
|
||||
if not skip_cache:
|
||||
result = self._try_cache(
|
||||
start_date, end_date, trusts, drugs, directories
|
||||
)
|
||||
if result:
|
||||
result.load_time_seconds = time.time() - start_time
|
||||
return result
|
||||
|
||||
# 2. Try Snowflake
|
||||
result = self._try_snowflake(
|
||||
start_date, end_date, trusts, drugs, directories, progress_callback
|
||||
)
|
||||
if result:
|
||||
# Cache the result for future queries
|
||||
if self._cache_enabled:
|
||||
self._cache_result(
|
||||
result.df,
|
||||
start_date, end_date, trusts, drugs, directories,
|
||||
includes_current_data=end_date is None or end_date >= date.today()
|
||||
)
|
||||
result.load_time_seconds = time.time() - start_time
|
||||
return result
|
||||
|
||||
# 3. Try SQLite
|
||||
result = self._try_sqlite(
|
||||
start_date, end_date, trusts, drugs, directories
|
||||
)
|
||||
if result:
|
||||
result.from_fallback = True # Mark as fallback since Snowflake wasn't used
|
||||
result.load_time_seconds = time.time() - start_time
|
||||
if warnings:
|
||||
result.warnings.extend(warnings)
|
||||
return result
|
||||
|
||||
# 4. Try local file
|
||||
result = self._try_file(
|
||||
start_date, end_date, trusts, drugs, directories
|
||||
)
|
||||
if result:
|
||||
result.from_fallback = True
|
||||
result.load_time_seconds = time.time() - start_time
|
||||
if warnings:
|
||||
result.warnings.extend(warnings)
|
||||
return result
|
||||
|
||||
# All sources failed
|
||||
source_status = self.check_all_sources()
|
||||
status_msg = "; ".join(
|
||||
f"{s.source_type.value}: {s.message}" for s in source_status
|
||||
)
|
||||
raise ValueError(f"No data source available. Status: {status_msg}")
|
||||
|
||||
def _cache_result(
|
||||
self,
|
||||
df: pd.DataFrame,
|
||||
start_date: Optional[date],
|
||||
end_date: Optional[date],
|
||||
trusts: Optional[list[str]],
|
||||
drugs: Optional[list[str]],
|
||||
directories: Optional[list[str]],
|
||||
includes_current_data: bool = False,
|
||||
) -> bool:
|
||||
"""Cache a query result for future use."""
|
||||
try:
|
||||
from data_processing.cache import get_cache
|
||||
|
||||
cache = get_cache()
|
||||
if not cache.is_enabled:
|
||||
return False
|
||||
|
||||
query, params = self._build_cache_key_params(
|
||||
start_date, end_date, trusts, drugs, directories
|
||||
)
|
||||
|
||||
# Convert DataFrame to list of dicts for caching
|
||||
# Convert datetime columns to strings for JSON serialization
|
||||
df_copy = df.copy()
|
||||
for col in df_copy.columns:
|
||||
if pd.api.types.is_datetime64_any_dtype(df_copy[col]):
|
||||
df_copy[col] = df_copy[col].astype(str)
|
||||
|
||||
data = df_copy.to_dict(orient='records')
|
||||
|
||||
entry = cache.set(
|
||||
query, params, data,
|
||||
includes_current_data=includes_current_data
|
||||
)
|
||||
|
||||
if entry:
|
||||
logger.info(f"Cached {len(data)} rows (key={entry.cache_key[:16]}...)")
|
||||
return True
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache result: {e}")
|
||||
return False
|
||||
|
||||
def clear_cache(self) -> int:
|
||||
"""
|
||||
Clear all cached data.
|
||||
|
||||
Returns:
|
||||
Number of cache entries cleared
|
||||
"""
|
||||
try:
|
||||
from data_processing.cache import get_cache
|
||||
cache = get_cache()
|
||||
return cache.clear()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clear cache: {e}")
|
||||
return 0
|
||||
|
||||
def refresh_from_snowflake(
|
||||
self,
|
||||
start_date: Optional[date] = None,
|
||||
end_date: Optional[date] = None,
|
||||
trusts: Optional[list[str]] = None,
|
||||
drugs: Optional[list[str]] = None,
|
||||
directories: Optional[list[str]] = None,
|
||||
progress_callback: Optional[Callable[[int, int], None]] = None,
|
||||
) -> DataSourceResult:
|
||||
"""
|
||||
Force a refresh from Snowflake, bypassing cache and other sources.
|
||||
|
||||
This method specifically queries Snowflake and will fail if Snowflake
|
||||
is not available or not configured.
|
||||
|
||||
Args:
|
||||
start_date: Optional start date for filtering
|
||||
end_date: Optional end date for filtering
|
||||
trusts: Optional list of trust names
|
||||
drugs: Optional list of drug names
|
||||
directories: Optional list of directories
|
||||
progress_callback: Optional progress callback
|
||||
|
||||
Returns:
|
||||
DataSourceResult from Snowflake
|
||||
|
||||
Raises:
|
||||
ValueError: If Snowflake is not available or query fails
|
||||
"""
|
||||
from data_processing.snowflake_connector import (
|
||||
is_snowflake_available,
|
||||
is_snowflake_configured,
|
||||
)
|
||||
|
||||
if not is_snowflake_available():
|
||||
raise ValueError("Snowflake connector not installed")
|
||||
|
||||
if not is_snowflake_configured():
|
||||
raise ValueError("Snowflake not configured - edit config/snowflake.toml")
|
||||
|
||||
result = self._try_snowflake(
|
||||
start_date, end_date, trusts, drugs, directories, progress_callback
|
||||
)
|
||||
|
||||
if result is None:
|
||||
raise ValueError("Snowflake query failed - check logs for details")
|
||||
|
||||
# Cache the fresh result
|
||||
if self._cache_enabled:
|
||||
self._cache_result(
|
||||
result.df,
|
||||
start_date, end_date, trusts, drugs, directories,
|
||||
includes_current_data=end_date is None or end_date >= date.today()
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# Module-level singleton and convenience functions
|
||||
_default_manager: Optional[DataSourceManager] = None
|
||||
|
||||
|
||||
def get_data_source_manager(
|
||||
cache_enabled: bool = True,
|
||||
local_file_path: Optional[Path | str] = None,
|
||||
sqlite_db_path: Optional[Path | str] = None,
|
||||
) -> DataSourceManager:
|
||||
"""
|
||||
Get a DataSourceManager instance.
|
||||
|
||||
Args:
|
||||
cache_enabled: Whether to enable caching
|
||||
local_file_path: Optional path to local CSV/Parquet file
|
||||
sqlite_db_path: Optional path to SQLite database
|
||||
|
||||
Returns:
|
||||
DataSourceManager instance
|
||||
"""
|
||||
global _default_manager
|
||||
|
||||
# If custom paths provided, create a new manager
|
||||
if local_file_path or sqlite_db_path:
|
||||
return DataSourceManager(
|
||||
cache_enabled=cache_enabled,
|
||||
local_file_path=local_file_path,
|
||||
sqlite_db_path=sqlite_db_path,
|
||||
)
|
||||
|
||||
# Otherwise use/create singleton
|
||||
if _default_manager is None:
|
||||
_default_manager = DataSourceManager(cache_enabled=cache_enabled)
|
||||
|
||||
return _default_manager
|
||||
|
||||
|
||||
def get_data(
|
||||
start_date: Optional[date] = None,
|
||||
end_date: Optional[date] = None,
|
||||
trusts: Optional[list[str]] = None,
|
||||
drugs: Optional[list[str]] = None,
|
||||
directories: Optional[list[str]] = None,
|
||||
preferred_source: Optional[str] = None,
|
||||
skip_cache: bool = False,
|
||||
) -> DataSourceResult:
|
||||
"""
|
||||
Convenience function to get data using the default manager.
|
||||
|
||||
Args:
|
||||
start_date: Optional start date for filtering
|
||||
end_date: Optional end date for filtering
|
||||
trusts: Optional list of trust names
|
||||
drugs: Optional list of drug names
|
||||
directories: Optional list of directories
|
||||
preferred_source: Optional preferred source
|
||||
skip_cache: If True, bypass cache
|
||||
|
||||
Returns:
|
||||
DataSourceResult with loaded data
|
||||
"""
|
||||
manager = get_data_source_manager()
|
||||
return manager.get_data(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
trusts=trusts,
|
||||
drugs=drugs,
|
||||
directories=directories,
|
||||
preferred_source=preferred_source,
|
||||
skip_cache=skip_cache,
|
||||
)
|
||||
|
||||
|
||||
def reset_data_source_manager() -> None:
|
||||
"""Reset the default data source manager singleton."""
|
||||
global _default_manager
|
||||
_default_manager = None
|
||||
|
||||
|
||||
# Export public API
|
||||
__all__ = [
|
||||
"DataSourceType",
|
||||
"DataSourceResult",
|
||||
"SourceStatus",
|
||||
"DataSourceManager",
|
||||
"get_data_source_manager",
|
||||
"get_data",
|
||||
"reset_data_source_manager",
|
||||
]
|
||||
@@ -0,0 +1,239 @@
|
||||
"""
|
||||
SQLite database connection management for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
Provides connection management, schema initialization, and common database operations.
|
||||
Uses context manager pattern for safe resource handling.
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Optional, Generator, Literal
|
||||
|
||||
from core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DatabaseConfig:
|
||||
"""
|
||||
Configuration for SQLite database location and connection parameters.
|
||||
|
||||
Attributes:
|
||||
db_path: Path to the SQLite database file
|
||||
timeout: Connection timeout in seconds (default: 30)
|
||||
isolation_level: Transaction isolation level (default: None for autocommit)
|
||||
"""
|
||||
|
||||
DEFAULT_DB_NAME = "pathways.db"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_path: Optional[Path] = None,
|
||||
data_dir: Optional[Path] = None,
|
||||
timeout: float = 30.0,
|
||||
isolation_level: Optional[Literal['DEFERRED', 'EXCLUSIVE', 'IMMEDIATE']] = None
|
||||
):
|
||||
"""
|
||||
Initialize database configuration.
|
||||
|
||||
Args:
|
||||
db_path: Full path to database file. If None, uses data_dir/DEFAULT_DB_NAME.
|
||||
data_dir: Directory to place database in. Defaults to ./data/
|
||||
timeout: Connection timeout in seconds.
|
||||
isolation_level: Transaction isolation level. None = autocommit.
|
||||
"""
|
||||
if db_path is not None:
|
||||
self.db_path = Path(db_path)
|
||||
elif data_dir is not None:
|
||||
self.db_path = Path(data_dir) / self.DEFAULT_DB_NAME
|
||||
else:
|
||||
self.db_path = Path("./data") / self.DEFAULT_DB_NAME
|
||||
|
||||
self.timeout = timeout
|
||||
self.isolation_level = isolation_level
|
||||
|
||||
def validate(self) -> list[str]:
|
||||
"""
|
||||
Validate database configuration.
|
||||
|
||||
Returns:
|
||||
List of error messages. Empty list means configuration is valid.
|
||||
"""
|
||||
errors = []
|
||||
|
||||
# Check parent directory exists
|
||||
parent_dir = self.db_path.parent
|
||||
if not parent_dir.exists():
|
||||
errors.append(f"Database directory does not exist: {parent_dir}")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""
|
||||
Manages SQLite database connections and operations.
|
||||
|
||||
Provides context manager for safe connection handling and methods
|
||||
for common database operations.
|
||||
|
||||
Usage:
|
||||
db_manager = DatabaseManager()
|
||||
|
||||
# Using context manager (recommended)
|
||||
with db_manager.get_connection() as conn:
|
||||
cursor = conn.execute("SELECT * FROM ref_drug_names")
|
||||
results = cursor.fetchall()
|
||||
|
||||
# Or get a managed connection for longer operations
|
||||
conn = db_manager.connect()
|
||||
try:
|
||||
# ... do work ...
|
||||
finally:
|
||||
conn.close()
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[DatabaseConfig] = None):
|
||||
"""
|
||||
Initialize the database manager.
|
||||
|
||||
Args:
|
||||
config: Database configuration. If None, uses default configuration.
|
||||
"""
|
||||
self.config = config or DatabaseConfig()
|
||||
self._connection: Optional[sqlite3.Connection] = None
|
||||
|
||||
@property
|
||||
def db_path(self) -> Path:
|
||||
"""Path to the SQLite database file."""
|
||||
return self.config.db_path
|
||||
|
||||
@property
|
||||
def exists(self) -> bool:
|
||||
"""Check if the database file exists."""
|
||||
return self.db_path.exists()
|
||||
|
||||
def connect(self) -> sqlite3.Connection:
|
||||
"""
|
||||
Create a new database connection.
|
||||
|
||||
Returns:
|
||||
sqlite3.Connection: New database connection.
|
||||
|
||||
Note:
|
||||
The caller is responsible for closing the connection.
|
||||
Consider using get_connection() context manager instead.
|
||||
"""
|
||||
conn = sqlite3.connect(
|
||||
str(self.db_path),
|
||||
timeout=self.config.timeout,
|
||||
isolation_level=self.config.isolation_level
|
||||
)
|
||||
# Enable foreign key support
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
# Return rows as sqlite3.Row for dict-like access
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self) -> Generator[sqlite3.Connection, None, None]:
|
||||
"""
|
||||
Context manager for database connections.
|
||||
|
||||
Yields:
|
||||
sqlite3.Connection: Database connection.
|
||||
|
||||
Example:
|
||||
with db_manager.get_connection() as conn:
|
||||
conn.execute("INSERT INTO table VALUES (?)", (value,))
|
||||
conn.commit()
|
||||
"""
|
||||
conn = self.connect()
|
||||
try:
|
||||
yield conn
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
@contextmanager
|
||||
def get_transaction(self) -> Generator[sqlite3.Connection, None, None]:
|
||||
"""
|
||||
Context manager for transactional operations.
|
||||
|
||||
Automatically commits on success, rolls back on exception.
|
||||
|
||||
Yields:
|
||||
sqlite3.Connection: Database connection in transaction mode.
|
||||
|
||||
Example:
|
||||
with db_manager.get_transaction() as conn:
|
||||
conn.execute("INSERT INTO table VALUES (?)", (value1,))
|
||||
conn.execute("INSERT INTO other_table VALUES (?)", (value2,))
|
||||
# Auto-commits if no exception
|
||||
"""
|
||||
conn = sqlite3.connect(
|
||||
str(self.db_path),
|
||||
timeout=self.config.timeout,
|
||||
isolation_level="DEFERRED" # Explicit transaction mode
|
||||
)
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
yield conn
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def execute_script(self, sql_script: str) -> None:
|
||||
"""
|
||||
Execute a SQL script (multiple statements).
|
||||
|
||||
Args:
|
||||
sql_script: SQL script containing one or more statements.
|
||||
"""
|
||||
with self.get_connection() as conn:
|
||||
conn.executescript(sql_script)
|
||||
logger.info("Executed SQL script successfully")
|
||||
|
||||
def table_exists(self, table_name: str) -> bool:
|
||||
"""
|
||||
Check if a table exists in the database.
|
||||
|
||||
Args:
|
||||
table_name: Name of the table to check.
|
||||
|
||||
Returns:
|
||||
True if the table exists, False otherwise.
|
||||
"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name=?",
|
||||
(table_name,)
|
||||
)
|
||||
return cursor.fetchone() is not None
|
||||
|
||||
def get_table_count(self, table_name: str) -> int:
|
||||
"""
|
||||
Get the row count for a table.
|
||||
|
||||
Args:
|
||||
table_name: Name of the table.
|
||||
|
||||
Returns:
|
||||
Number of rows in the table.
|
||||
"""
|
||||
with self.get_connection() as conn:
|
||||
# Use parameterized table name via string formatting (safe since we control table_name)
|
||||
cursor = conn.execute(f"SELECT COUNT(*) FROM {table_name}")
|
||||
result = cursor.fetchone()
|
||||
return result[0] if result else 0
|
||||
|
||||
|
||||
# Default instance for application-wide use
|
||||
default_db_config = DatabaseConfig()
|
||||
default_db_manager = DatabaseManager(default_db_config)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,246 @@
|
||||
"""
|
||||
Data loader abstractions for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
Provides a unified interface for loading patient intervention data from:
|
||||
- CSV/Parquet files (current behavior)
|
||||
- SQLite database (new, faster approach)
|
||||
- Snowflake (future, direct from warehouse)
|
||||
|
||||
The DataLoader ABC defines the contract for all loader implementations.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from core import PathConfig, default_paths
|
||||
from core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadResult:
|
||||
"""Result of a data load operation.
|
||||
|
||||
Attributes:
|
||||
df: The loaded DataFrame with processed patient intervention data
|
||||
source: Description of the data source (e.g., "file:/path/to/file.csv")
|
||||
row_count: Number of rows loaded
|
||||
columns: List of column names in the DataFrame
|
||||
load_time_seconds: Time taken to load the data
|
||||
"""
|
||||
df: pd.DataFrame
|
||||
source: str
|
||||
row_count: int
|
||||
columns: list[str] = field(default_factory=list)
|
||||
load_time_seconds: float = 0.0
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.columns:
|
||||
self.columns = list(self.df.columns)
|
||||
|
||||
|
||||
# Expected columns in a processed DataFrame
|
||||
# These are the columns that generate_graph() expects to receive
|
||||
REQUIRED_COLUMNS = [
|
||||
"UPID", # Unique Patient ID (Provider Code prefix + PersonKey)
|
||||
"Drug Name", # Standardized drug name
|
||||
"Intervention Date", # Date of intervention
|
||||
"Price Actual", # Cost of intervention
|
||||
"OrganisationName", # NHS Trust name
|
||||
"Directory", # Medical specialty/directory
|
||||
"Provider Code", # NHS provider code
|
||||
"PersonKey", # Patient identifier within provider
|
||||
]
|
||||
|
||||
# Additional columns that are useful but not strictly required
|
||||
OPTIONAL_COLUMNS = [
|
||||
"UPIDTreatment", # UPID + Drug Name combo (created by generate_graph)
|
||||
"Treatment Function Code", # NHS treatment function code
|
||||
"Additional Detail 1",
|
||||
"Additional Detail 2",
|
||||
"Additional Detail 3",
|
||||
"Additional Detail 4",
|
||||
"Additional Detail 5",
|
||||
]
|
||||
|
||||
|
||||
class DataLoader(ABC):
|
||||
"""Abstract base class for data loaders.
|
||||
|
||||
All data loaders must implement the load() method which returns
|
||||
a DataFrame ready for use by generate_graph().
|
||||
|
||||
The returned DataFrame must contain REQUIRED_COLUMNS at minimum.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def load(self) -> LoadResult:
|
||||
"""Load and process patient intervention data.
|
||||
|
||||
Returns:
|
||||
LoadResult containing the processed DataFrame and metadata.
|
||||
The DataFrame must contain all REQUIRED_COLUMNS.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the data source doesn't exist
|
||||
ValueError: If the data is malformed or missing required columns
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_source(self) -> tuple[bool, str]:
|
||||
"""Check if the data source is valid and accessible.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, message).
|
||||
If is_valid is False, message explains the issue.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def source_description(self) -> str:
|
||||
"""Human-readable description of the data source."""
|
||||
pass
|
||||
|
||||
def validate_dataframe(self, df: pd.DataFrame) -> tuple[bool, list[str]]:
|
||||
"""Validate that a DataFrame has all required columns.
|
||||
|
||||
Args:
|
||||
df: DataFrame to validate
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, missing_columns).
|
||||
If is_valid is False, missing_columns lists what's missing.
|
||||
"""
|
||||
missing = [col for col in REQUIRED_COLUMNS if col not in df.columns]
|
||||
return len(missing) == 0, missing
|
||||
|
||||
|
||||
class FileDataLoader(DataLoader):
|
||||
"""Loads data from CSV or Parquet files.
|
||||
|
||||
This replicates the current behavior of dashboard_gui.main():
|
||||
1. Read CSV or Parquet file
|
||||
2. Apply patient_id() transformation
|
||||
3. Convert dates
|
||||
4. Apply drug_names() standardization
|
||||
5. Clean organization names
|
||||
6. Apply department_identification()
|
||||
|
||||
Args:
|
||||
file_path: Path to the CSV or Parquet file
|
||||
paths: PathConfig for reference data file locations (uses default_paths if None)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: Path | str,
|
||||
paths: Optional[PathConfig] = None,
|
||||
):
|
||||
self.file_path = Path(file_path)
|
||||
self.paths = paths or default_paths
|
||||
|
||||
def validate_source(self) -> tuple[bool, str]:
|
||||
"""Check if the file exists and has a supported extension."""
|
||||
if not self.file_path.exists():
|
||||
return False, f"File not found: {self.file_path}"
|
||||
|
||||
ext = self.file_path.suffix.lower()
|
||||
if ext not in ('.csv', '.parquet'):
|
||||
return False, f"Unsupported file type: {ext}. Must be .csv or .parquet"
|
||||
|
||||
return True, "OK"
|
||||
|
||||
@property
|
||||
def source_description(self) -> str:
|
||||
return f"file:{self.file_path}"
|
||||
|
||||
def load(self) -> LoadResult:
|
||||
"""Load and process data from CSV or Parquet file.
|
||||
|
||||
Applies the same transformation pipeline as the original
|
||||
dashboard_gui.main() function.
|
||||
"""
|
||||
import time
|
||||
from data_processing import transforms as data
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Validate source before loading
|
||||
is_valid, msg = self.validate_source()
|
||||
if not is_valid:
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
# Read file based on extension
|
||||
ext = self.file_path.suffix.lower()
|
||||
logger.info(f"Reading {ext} file: {self.file_path}")
|
||||
|
||||
if ext == '.csv':
|
||||
df_raw = pd.read_csv(self.file_path, low_memory=False)
|
||||
else: # .parquet
|
||||
df_raw = pd.read_parquet(self.file_path)
|
||||
|
||||
logger.info(f"File read successfully. {len(df_raw)} rows.")
|
||||
|
||||
# Apply transformations (same as dashboard_gui.main())
|
||||
df = data.patient_id(df_raw)
|
||||
logger.info("Patient ID processing complete.")
|
||||
|
||||
df['Intervention Date'] = pd.to_datetime(df['Intervention Date'], format="%Y-%m-%d")
|
||||
logger.info("Date conversion complete.")
|
||||
|
||||
# Preserve original drug name before standardization (for SQLite storage)
|
||||
df['Drug Name Raw'] = df['Drug Name'].copy()
|
||||
|
||||
df = data.drug_names(df, self.paths)
|
||||
logger.info("Drug name processing complete.")
|
||||
|
||||
df['OrganisationName'] = df['OrganisationName'].str.replace(',', '')
|
||||
logger.info("Organisation name cleaning complete.")
|
||||
|
||||
df = data.department_identification(df, self.paths)
|
||||
logger.info("Department identification complete.")
|
||||
|
||||
# Validate result
|
||||
is_valid, missing = self.validate_dataframe(df)
|
||||
if not is_valid:
|
||||
raise ValueError(f"Processed DataFrame missing required columns: {missing}")
|
||||
|
||||
load_time = time.time() - start_time
|
||||
logger.info(f"Data loading complete. {len(df)} rows in {load_time:.2f}s")
|
||||
|
||||
return LoadResult(
|
||||
df=df,
|
||||
source=self.source_description,
|
||||
row_count=len(df),
|
||||
load_time_seconds=load_time,
|
||||
)
|
||||
|
||||
|
||||
def get_loader(
|
||||
source: str | Path,
|
||||
paths: Optional[PathConfig] = None,
|
||||
**kwargs
|
||||
) -> DataLoader:
|
||||
"""Factory function to create the appropriate DataLoader.
|
||||
|
||||
Args:
|
||||
source: File path (CSV/Parquet)
|
||||
paths: PathConfig for reference data (used by FileDataLoader)
|
||||
**kwargs: Additional arguments passed to the loader constructor
|
||||
|
||||
Returns:
|
||||
Appropriate DataLoader instance
|
||||
|
||||
Examples:
|
||||
>>> loader = get_loader("data/activity.csv")
|
||||
>>> loader = get_loader("data/activity.parquet")
|
||||
"""
|
||||
path = Path(source)
|
||||
return FileDataLoader(file_path=path, paths=paths)
|
||||
@@ -0,0 +1,469 @@
|
||||
"""
|
||||
Database migration script for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
Provides functions to initialize the SQLite database schema and CLI interface
|
||||
for running migrations from the command line.
|
||||
|
||||
Usage:
|
||||
# Initialize database (creates all tables)
|
||||
python -m data_processing.migrate
|
||||
|
||||
# Drop existing tables and reinitialize
|
||||
python -m data_processing.migrate --drop-existing
|
||||
|
||||
# Show current database status
|
||||
python -m data_processing.migrate --status
|
||||
|
||||
# Migrate all reference data from CSV files
|
||||
python -m data_processing.migrate --reference-data
|
||||
|
||||
# Migrate reference data with verification
|
||||
python -m data_processing.migrate --reference-data --verify
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
# Ensure src/ is on sys.path when run as `python -m data_processing.migrate`
|
||||
_src_dir = str(Path(__file__).resolve().parent.parent)
|
||||
if _src_dir not in sys.path:
|
||||
sys.path.insert(0, _src_dir)
|
||||
|
||||
from core.logging_config import setup_logging, get_logger
|
||||
from data_processing.database import DatabaseManager, DatabaseConfig
|
||||
from core import PathConfig, default_paths
|
||||
from data_processing.schema import (
|
||||
create_all_tables,
|
||||
drop_all_tables,
|
||||
verify_all_tables_exist,
|
||||
get_all_table_counts,
|
||||
migrate_pathway_nodes_chart_type,
|
||||
migrate_refresh_log_source_row_count,
|
||||
)
|
||||
from data_processing.reference_data import (
|
||||
MigrationResult,
|
||||
migrate_drug_names,
|
||||
migrate_organizations,
|
||||
migrate_directories,
|
||||
migrate_drug_directory_map,
|
||||
migrate_drug_indication_clusters,
|
||||
verify_drug_names_migration,
|
||||
verify_organizations_migration,
|
||||
verify_directories_migration,
|
||||
verify_drug_directory_map_migration,
|
||||
verify_drug_indication_clusters_migration,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def initialize_database(
|
||||
db_manager: Optional[DatabaseManager] = None,
|
||||
drop_existing: bool = False,
|
||||
confirm_drop: bool = True
|
||||
) -> bool:
|
||||
"""
|
||||
Initialize the database with all required tables.
|
||||
|
||||
Creates all tables defined in the schema (reference tables and pathway
|
||||
tables). Uses IF NOT EXISTS so safe to run multiple times.
|
||||
|
||||
Args:
|
||||
db_manager: DatabaseManager instance. Uses default if not provided.
|
||||
drop_existing: If True, drops all existing tables before creating.
|
||||
confirm_drop: If True and drop_existing=True, prompts for confirmation.
|
||||
Set to False for non-interactive use.
|
||||
|
||||
Returns:
|
||||
True if initialization succeeded, False otherwise.
|
||||
"""
|
||||
if db_manager is None:
|
||||
db_manager = DatabaseManager()
|
||||
|
||||
logger.info(f"Initializing database at: {db_manager.db_path}")
|
||||
|
||||
# Handle drop existing with confirmation
|
||||
if drop_existing:
|
||||
if confirm_drop:
|
||||
print(f"\nWARNING: This will delete ALL data from the database:")
|
||||
print(f" {db_manager.db_path}\n")
|
||||
response = input("Are you sure you want to continue? (yes/no): ")
|
||||
if response.lower() not in ("yes", "y"):
|
||||
print("Operation cancelled.")
|
||||
return False
|
||||
|
||||
if db_manager.exists:
|
||||
logger.warning("Dropping existing tables...")
|
||||
with db_manager.get_connection() as conn:
|
||||
drop_all_tables(conn)
|
||||
conn.commit()
|
||||
logger.info("Existing tables dropped")
|
||||
else:
|
||||
logger.info("Database does not exist yet, nothing to drop")
|
||||
|
||||
# Create all tables
|
||||
try:
|
||||
with db_manager.get_transaction() as conn:
|
||||
create_all_tables(conn)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create tables: {e}")
|
||||
return False
|
||||
|
||||
# Run migrations for schema changes
|
||||
try:
|
||||
with db_manager.get_connection() as conn:
|
||||
# Add chart_type column to pathway_nodes if it doesn't exist
|
||||
success, msg = migrate_pathway_nodes_chart_type(conn)
|
||||
if success:
|
||||
logger.info(f"pathway_nodes migration: {msg}")
|
||||
else:
|
||||
logger.error(f"pathway_nodes migration failed: {msg}")
|
||||
return False
|
||||
|
||||
# Add source_row_count column to pathway_refresh_log if it doesn't exist
|
||||
success, msg = migrate_refresh_log_source_row_count(conn)
|
||||
if success:
|
||||
logger.info(f"pathway_refresh_log migration: {msg}")
|
||||
else:
|
||||
logger.error(f"pathway_refresh_log migration failed: {msg}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Migration failed: {e}")
|
||||
return False
|
||||
|
||||
# Verify all tables were created
|
||||
with db_manager.get_connection() as conn:
|
||||
missing = verify_all_tables_exist(conn)
|
||||
|
||||
if missing:
|
||||
logger.error(f"Table creation failed. Missing tables: {missing}")
|
||||
return False
|
||||
|
||||
logger.info("All tables created successfully")
|
||||
return True
|
||||
|
||||
|
||||
def migrate_all_reference_data(
|
||||
db_manager: Optional[DatabaseManager] = None,
|
||||
paths: Optional[PathConfig] = None,
|
||||
verify: bool = False
|
||||
) -> tuple[bool, list[MigrationResult]]:
|
||||
"""
|
||||
Run all reference data migrations from CSV files to SQLite tables.
|
||||
|
||||
Migrations are run in order:
|
||||
1. Drug names (drugnames.csv → ref_drug_names)
|
||||
2. Organizations (org_codes.csv → ref_organizations)
|
||||
3. Directories (directory_list.csv → ref_directories)
|
||||
4. Drug-directory mappings (drug_directory_list.csv → ref_drug_directory_map)
|
||||
|
||||
Args:
|
||||
db_manager: DatabaseManager instance. Uses default if not provided.
|
||||
paths: PathConfig instance for locating CSV files. Uses default if not provided.
|
||||
verify: If True, runs verification after each migration.
|
||||
|
||||
Returns:
|
||||
Tuple of (all_success: bool, results: list of MigrationResult)
|
||||
"""
|
||||
if db_manager is None:
|
||||
db_manager = DatabaseManager()
|
||||
if paths is None:
|
||||
paths = default_paths
|
||||
|
||||
results: list[MigrationResult] = []
|
||||
all_success = True
|
||||
|
||||
# Define migrations in order
|
||||
# Note: drug_indication_clusters uses a different signature (csv_path instead of paths)
|
||||
migrations = [
|
||||
("Drug names", migrate_drug_names, verify_drug_names_migration if verify else None, True),
|
||||
("Organizations", migrate_organizations, verify_organizations_migration if verify else None, True),
|
||||
("Directories", migrate_directories, verify_directories_migration if verify else None, True),
|
||||
("Drug-directory map", migrate_drug_directory_map, verify_drug_directory_map_migration if verify else None, True),
|
||||
("Drug indication clusters", migrate_drug_indication_clusters, verify_drug_indication_clusters_migration if verify else None, False),
|
||||
]
|
||||
|
||||
logger.info(f"Starting reference data migrations ({len(migrations)} tables)")
|
||||
|
||||
for name, migrate_fn, verify_fn, uses_paths in migrations:
|
||||
logger.info(f"Migrating: {name}...")
|
||||
|
||||
# Run migration (some use paths parameter, some use csv_path)
|
||||
if uses_paths:
|
||||
result = migrate_fn(db_manager=db_manager, paths=paths) # type: ignore[operator]
|
||||
else:
|
||||
# Drug indication clusters uses csv_path instead of paths
|
||||
result = migrate_fn(db_manager=db_manager) # type: ignore[operator]
|
||||
results.append(result)
|
||||
|
||||
if not result.success:
|
||||
logger.error(f"Migration failed: {name} - {result.error_message}")
|
||||
all_success = False
|
||||
continue
|
||||
|
||||
logger.info(f" {result}")
|
||||
|
||||
# Run verification if requested
|
||||
if verify_fn is not None:
|
||||
logger.info(f" Verifying {name}...")
|
||||
if uses_paths:
|
||||
verified, verify_msg = verify_fn(db_manager=db_manager, paths=paths) # type: ignore[call-arg]
|
||||
else:
|
||||
verified, verify_msg = verify_fn(db_manager=db_manager) # type: ignore[call-arg]
|
||||
if verified:
|
||||
logger.info(f" OK: {verify_msg}")
|
||||
else:
|
||||
logger.error(f" FAILED: Verification failed: {verify_msg}")
|
||||
all_success = False
|
||||
|
||||
# Summary
|
||||
successful = sum(1 for r in results if r.success)
|
||||
logger.info(f"Reference data migrations complete: {successful}/{len(results)} succeeded")
|
||||
|
||||
return all_success, results
|
||||
|
||||
|
||||
def print_migration_summary(results: list[MigrationResult]) -> None:
|
||||
"""Print a summary of migration results to stdout."""
|
||||
print("\n=== Reference Data Migration Summary ===\n")
|
||||
|
||||
for result in results:
|
||||
status = "[OK]" if result.success else "[FAILED]"
|
||||
print(f"{status} {result.table_name}")
|
||||
if result.success:
|
||||
print(f" Read: {result.rows_read}, Inserted: {result.rows_inserted}, Skipped: {result.rows_skipped}")
|
||||
else:
|
||||
print(f" Error: {result.error_message}")
|
||||
|
||||
successful = sum(1 for r in results if r.success)
|
||||
print(f"\nTotal: {successful}/{len(results)} migrations succeeded")
|
||||
print()
|
||||
|
||||
|
||||
def create_progress_reporter(description: str = "Loading", width: int = 40):
|
||||
"""
|
||||
Create a progress callback that prints a progress bar to stdout.
|
||||
|
||||
Args:
|
||||
description: Label to show before the progress bar.
|
||||
width: Width of the progress bar in characters.
|
||||
|
||||
Returns:
|
||||
Callback function(current, total) that prints progress.
|
||||
"""
|
||||
last_percent = [-1] # Use list to allow mutation in closure
|
||||
|
||||
def report_progress(current: int, total: int) -> None:
|
||||
"""Print a progress bar showing current/total progress."""
|
||||
if total == 0:
|
||||
percent = 100
|
||||
else:
|
||||
percent = int(100 * current / total)
|
||||
|
||||
# Only update display when percentage changes (avoid excessive output)
|
||||
if percent == last_percent[0]:
|
||||
return
|
||||
last_percent[0] = percent
|
||||
|
||||
filled = int(width * current / total) if total > 0 else width
|
||||
bar = "=" * filled + "-" * (width - filled)
|
||||
|
||||
# Use carriage return to overwrite the line
|
||||
sys.stdout.write(f"\r{description}: [{bar}] {percent:3d}% ({current:,}/{total:,})")
|
||||
sys.stdout.flush()
|
||||
|
||||
# Print newline when complete
|
||||
if current >= total:
|
||||
print()
|
||||
|
||||
return report_progress
|
||||
|
||||
|
||||
def get_database_status(db_manager: Optional[DatabaseManager] = None) -> dict:
|
||||
"""
|
||||
Get the current status of the database.
|
||||
|
||||
Returns:
|
||||
Dictionary with database status information:
|
||||
- exists: Whether the database file exists
|
||||
- path: Path to the database file
|
||||
- size_bytes: Size of database file (if exists)
|
||||
- tables: Dictionary of table names to row counts
|
||||
- missing_tables: List of expected tables that don't exist
|
||||
"""
|
||||
if db_manager is None:
|
||||
db_manager = DatabaseManager()
|
||||
|
||||
status = {
|
||||
"exists": db_manager.exists,
|
||||
"path": str(db_manager.db_path),
|
||||
"size_bytes": None,
|
||||
"tables": {},
|
||||
"missing_tables": [],
|
||||
}
|
||||
|
||||
if db_manager.exists:
|
||||
status["size_bytes"] = db_manager.db_path.stat().st_size
|
||||
|
||||
with db_manager.get_connection() as conn:
|
||||
status["missing_tables"] = verify_all_tables_exist(conn)
|
||||
|
||||
# Get counts for existing tables
|
||||
try:
|
||||
status["tables"] = get_all_table_counts(conn)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not get table counts: {e}")
|
||||
|
||||
return status
|
||||
|
||||
|
||||
def print_database_status(db_manager: Optional[DatabaseManager] = None) -> None:
|
||||
"""Print database status to stdout in a human-readable format."""
|
||||
status = get_database_status(db_manager)
|
||||
|
||||
print("\n=== Database Status ===\n")
|
||||
print(f"Path: {status['path']}")
|
||||
print(f"Exists: {status['exists']}")
|
||||
|
||||
if status["exists"]:
|
||||
size_kb = (status["size_bytes"] or 0) / 1024
|
||||
print(f"Size: {size_kb:.1f} KB")
|
||||
|
||||
if status["missing_tables"]:
|
||||
print(f"\nMissing tables: {', '.join(status['missing_tables'])}")
|
||||
else:
|
||||
print("\nAll expected tables exist.")
|
||||
|
||||
if status["tables"]:
|
||||
print("\nTable row counts:")
|
||||
for table, count in sorted(status["tables"].items()):
|
||||
print(f" {table}: {count:,} rows")
|
||||
else:
|
||||
print("\nDatabase does not exist. Run migration to create it.")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def main():
|
||||
"""CLI entry point for database migration."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Initialize NHS Pathways Analysis SQLite database schema",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
python -m data_processing.migrate # Initialize database
|
||||
python -m data_processing.migrate --status # Show database status
|
||||
python -m data_processing.migrate --drop-existing # Reset database
|
||||
python -m data_processing.migrate --reference-data # Migrate reference data
|
||||
python -m data_processing.migrate --reference-data --verify # With verification
|
||||
python -m data_processing.migrate --db-path ./data/test.db # Custom path
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--status",
|
||||
action="store_true",
|
||||
help="Show current database status and exit"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--drop-existing",
|
||||
action="store_true",
|
||||
help="Drop all existing tables before creating (WARNING: deletes data)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reference-data",
|
||||
action="store_true",
|
||||
help="Migrate all reference data from CSV files to SQLite tables"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verify",
|
||||
action="store_true",
|
||||
help="Verify migrated data matches CSV sources (use with --reference-data)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--db-path",
|
||||
type=Path,
|
||||
help="Path to database file (default: ./data/pathways.db)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--yes", "-y",
|
||||
action="store_true",
|
||||
help="Skip confirmation prompts (for non-interactive use)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose", "-v",
|
||||
action="store_true",
|
||||
help="Enable verbose logging"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set up logging
|
||||
log_level = "DEBUG" if args.verbose else "INFO"
|
||||
setup_logging(level=log_level, simple_console=True)
|
||||
|
||||
# Create database manager with optional custom path
|
||||
if args.db_path:
|
||||
config = DatabaseConfig(db_path=args.db_path)
|
||||
db_manager = DatabaseManager(config)
|
||||
else:
|
||||
db_manager = DatabaseManager()
|
||||
|
||||
# Handle --status
|
||||
if args.status:
|
||||
print_database_status(db_manager)
|
||||
return 0
|
||||
|
||||
# Validate configuration
|
||||
config_errors = db_manager.config.validate()
|
||||
if config_errors:
|
||||
for error in config_errors:
|
||||
logger.error(error)
|
||||
return 1
|
||||
|
||||
# Handle --reference-data (migrate reference data from CSV to SQLite)
|
||||
if args.reference_data:
|
||||
# Ensure database exists with tables first
|
||||
if not db_manager.exists:
|
||||
print("Database does not exist. Initializing schema first...")
|
||||
success = initialize_database(db_manager=db_manager)
|
||||
if not success:
|
||||
print("\nDatabase initialization failed. Check logs for details.")
|
||||
return 1
|
||||
|
||||
# Run reference data migrations
|
||||
success, results = migrate_all_reference_data(
|
||||
db_manager=db_manager,
|
||||
paths=default_paths,
|
||||
verify=args.verify
|
||||
)
|
||||
|
||||
print_migration_summary(results)
|
||||
print_database_status(db_manager)
|
||||
|
||||
if success:
|
||||
print("Reference data migration completed successfully.")
|
||||
return 0
|
||||
else:
|
||||
print("Reference data migration completed with errors. Check logs for details.")
|
||||
return 1
|
||||
|
||||
# Run schema migration (default behavior)
|
||||
success = initialize_database(
|
||||
db_manager=db_manager,
|
||||
drop_existing=args.drop_existing,
|
||||
confirm_drop=not args.yes
|
||||
)
|
||||
|
||||
if success:
|
||||
print("\nDatabase initialized successfully.")
|
||||
print_database_status(db_manager)
|
||||
return 0
|
||||
else:
|
||||
print("\nDatabase initialization failed. Check logs for details.")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -0,0 +1,642 @@
|
||||
"""
|
||||
Pathway data processing pipeline.
|
||||
|
||||
This module provides functions to:
|
||||
1. Fetch and transform raw intervention data from Snowflake
|
||||
2. Process data for each of the 6 date filter combinations
|
||||
3. Extract denormalized fields from hierarchical path strings
|
||||
4. Convert processed data to records for SQLite storage
|
||||
|
||||
The pipeline integrates with:
|
||||
- analysis/pathway_analyzer.py: generate_icicle_chart() for pathway processing
|
||||
- data_processing/snowflake_connector.py: fetch_activity_data() for data retrieval
|
||||
- tools/data.py: patient_id(), drug_names(), department_identification()
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, timedelta
|
||||
from typing import Optional, Literal
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from core import PathConfig, default_paths
|
||||
from core.logging_config import get_logger
|
||||
from analysis.pathway_analyzer import generate_icicle_chart, generate_icicle_chart_indication
|
||||
from data_processing.transforms import patient_id, drug_names, department_identification
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Type alias for chart types
|
||||
ChartType = Literal["directory", "indication"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DateFilterConfig:
|
||||
"""Configuration for a date filter combination."""
|
||||
|
||||
id: str # e.g., 'all_6mo', '1yr_12mo'
|
||||
initiated_years: Optional[int] # None for 'All', 1, or 2
|
||||
last_seen_months: int # 6 or 12
|
||||
|
||||
|
||||
# Pre-defined date filter configurations matching pathway_date_filters table
|
||||
DATE_FILTER_CONFIGS = [
|
||||
DateFilterConfig(id="all_6mo", initiated_years=None, last_seen_months=6),
|
||||
DateFilterConfig(id="all_12mo", initiated_years=None, last_seen_months=12),
|
||||
DateFilterConfig(id="1yr_6mo", initiated_years=1, last_seen_months=6),
|
||||
DateFilterConfig(id="1yr_12mo", initiated_years=1, last_seen_months=12),
|
||||
DateFilterConfig(id="2yr_6mo", initiated_years=2, last_seen_months=6),
|
||||
DateFilterConfig(id="2yr_12mo", initiated_years=2, last_seen_months=12),
|
||||
]
|
||||
|
||||
|
||||
def compute_date_ranges(
|
||||
config: DateFilterConfig,
|
||||
max_date: Optional[date] = None,
|
||||
) -> tuple[str, str, str]:
|
||||
"""
|
||||
Compute actual date strings from a date filter configuration.
|
||||
|
||||
Args:
|
||||
config: DateFilterConfig with initiated_years and last_seen_months
|
||||
max_date: Reference date (defaults to today)
|
||||
|
||||
Returns:
|
||||
Tuple of (start_date, end_date, last_seen_date) as ISO format strings
|
||||
- start_date: Start of initiated filter period
|
||||
- end_date: End of initiated filter period (usually max_date)
|
||||
- last_seen_date: Date threshold for last_seen filter
|
||||
"""
|
||||
if max_date is None:
|
||||
max_date = date.today()
|
||||
|
||||
# Calculate end_date (always max_date)
|
||||
end_date = max_date
|
||||
|
||||
# Calculate start_date based on initiated_years
|
||||
if config.initiated_years is None:
|
||||
# "All years" - use a very old date
|
||||
start_date = date(2000, 1, 1)
|
||||
else:
|
||||
# Last N years from max_date
|
||||
start_date = max_date.replace(year=max_date.year - config.initiated_years)
|
||||
|
||||
# Calculate last_seen_date based on last_seen_months
|
||||
# Patients must have been seen within the last N months
|
||||
last_seen_date = max_date - timedelta(days=config.last_seen_months * 30)
|
||||
|
||||
return (
|
||||
start_date.isoformat(),
|
||||
end_date.isoformat(),
|
||||
last_seen_date.isoformat(),
|
||||
)
|
||||
|
||||
|
||||
def fetch_and_transform_data(
|
||||
start_date: Optional[date] = None,
|
||||
end_date: Optional[date] = None,
|
||||
provider_codes: Optional[list[str]] = None,
|
||||
paths: Optional[PathConfig] = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Fetch data from Snowflake and apply standard transformations.
|
||||
|
||||
This function:
|
||||
1. Fetches raw intervention data from Snowflake
|
||||
2. Applies UPID generation (Provider Code[:3] + PersonKey)
|
||||
3. Standardizes drug names via drugnames.csv mapping
|
||||
4. Assigns directories using the 5-level fallback logic
|
||||
|
||||
Args:
|
||||
start_date: Optional start date filter for Snowflake query
|
||||
end_date: Optional end date filter for Snowflake query
|
||||
provider_codes: Optional list of provider codes to filter
|
||||
paths: PathConfig for file paths (uses default if None)
|
||||
|
||||
Returns:
|
||||
DataFrame with columns: UPID, Drug Name, Directory, Intervention Date,
|
||||
Price Actual, Provider Code, PersonKey, OrganisationName, etc.
|
||||
|
||||
Raises:
|
||||
ImportError: If snowflake-connector-python is not installed
|
||||
SnowflakeConnectionError: If connection fails
|
||||
"""
|
||||
if paths is None:
|
||||
paths = default_paths
|
||||
|
||||
# Import here to avoid circular imports and handle optional dependency
|
||||
from data_processing.snowflake_connector import get_connector, is_snowflake_available
|
||||
|
||||
if not is_snowflake_available():
|
||||
raise ImportError(
|
||||
"snowflake-connector-python is not installed. "
|
||||
"Install it with: pip install snowflake-connector-python"
|
||||
)
|
||||
|
||||
logger.info("Fetching activity data from Snowflake...")
|
||||
|
||||
connector = get_connector()
|
||||
raw_data = connector.fetch_activity_data(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
provider_codes=provider_codes,
|
||||
max_rows=0, # No limit
|
||||
)
|
||||
|
||||
if not raw_data:
|
||||
logger.warning("No data returned from Snowflake")
|
||||
return pd.DataFrame()
|
||||
|
||||
logger.info(f"Fetched {len(raw_data)} records from Snowflake")
|
||||
|
||||
# Convert to DataFrame
|
||||
df = pd.DataFrame(raw_data)
|
||||
|
||||
# Apply transformations in the standard order
|
||||
logger.info("Applying data transformations...")
|
||||
|
||||
# 1. Generate UPID
|
||||
df = patient_id(df)
|
||||
logger.info(f"Generated UPID for {df['UPID'].nunique()} unique patients")
|
||||
|
||||
# 2. Standardize drug names
|
||||
df = drug_names(df, paths)
|
||||
# Remove rows where drug name mapping failed (NaN)
|
||||
before_count = len(df)
|
||||
df = df.dropna(subset=['Drug Name'])
|
||||
after_count = len(df)
|
||||
if before_count != after_count:
|
||||
logger.info(f"Removed {before_count - after_count} rows with unmapped drug names")
|
||||
|
||||
# 3. Assign directories
|
||||
df = department_identification(df, paths)
|
||||
logger.info(f"Assigned directories to {len(df)} records")
|
||||
|
||||
# Ensure Intervention Date is datetime
|
||||
df['Intervention Date'] = pd.to_datetime(df['Intervention Date'])
|
||||
|
||||
logger.info(f"Data transformation complete. Final record count: {len(df)}")
|
||||
return df
|
||||
|
||||
|
||||
def process_pathway_for_date_filter(
|
||||
df: pd.DataFrame,
|
||||
config: DateFilterConfig,
|
||||
trust_filter: list[str],
|
||||
drug_filter: list[str],
|
||||
directory_filter: list[str],
|
||||
minimum_patients: int = 5,
|
||||
max_date: Optional[date] = None,
|
||||
paths: Optional[PathConfig] = None,
|
||||
) -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
Process pathway data for a single date filter configuration.
|
||||
|
||||
Uses the existing generate_icicle_chart() function from pathway_analyzer.py
|
||||
to build the pathway hierarchy with treatment statistics.
|
||||
|
||||
Args:
|
||||
df: Transformed DataFrame from fetch_and_transform_data()
|
||||
config: DateFilterConfig specifying the date filter combination
|
||||
trust_filter: List of trust names to include
|
||||
drug_filter: List of drug names to include
|
||||
directory_filter: List of directories to include
|
||||
minimum_patients: Minimum patients to include a pathway
|
||||
max_date: Reference date for computing date ranges
|
||||
paths: PathConfig for file paths
|
||||
|
||||
Returns:
|
||||
DataFrame with pathway hierarchy (ice_df) or None if no data
|
||||
"""
|
||||
if paths is None:
|
||||
paths = default_paths
|
||||
|
||||
# Compute actual date ranges for this filter config
|
||||
start_date, end_date, last_seen_date = compute_date_ranges(config, max_date)
|
||||
|
||||
logger.info(f"Processing pathway for {config.id}")
|
||||
logger.info(f" Date range: {start_date} to {end_date}")
|
||||
logger.info(f" Last seen after: {last_seen_date}")
|
||||
|
||||
# Use the existing pathway analyzer
|
||||
ice_df, title = generate_icicle_chart(
|
||||
df=df,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
last_seen_date=last_seen_date,
|
||||
trust_filter=trust_filter,
|
||||
drug_filter=drug_filter,
|
||||
directory_filter=directory_filter,
|
||||
minimum_num_patients=minimum_patients,
|
||||
title="",
|
||||
paths=paths,
|
||||
)
|
||||
|
||||
if ice_df is None or len(ice_df) == 0:
|
||||
logger.warning(f"No pathway data for filter {config.id}")
|
||||
return None
|
||||
|
||||
logger.info(f"Generated {len(ice_df)} pathway nodes for {config.id}")
|
||||
return ice_df
|
||||
|
||||
|
||||
def process_indication_pathway_for_date_filter(
|
||||
df: pd.DataFrame,
|
||||
indication_df: pd.DataFrame,
|
||||
config: DateFilterConfig,
|
||||
trust_filter: list[str],
|
||||
drug_filter: list[str],
|
||||
directory_filter: list[str],
|
||||
minimum_patients: int = 5,
|
||||
max_date: Optional[date] = None,
|
||||
paths: Optional[PathConfig] = None,
|
||||
) -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
Process indication-based pathway data for a single date filter configuration.
|
||||
|
||||
This is similar to process_pathway_for_date_filter() but uses indication-based
|
||||
grouping (Search_Term from GP diagnosis) instead of directory grouping.
|
||||
|
||||
Hierarchy: Trust → Indication_Group → Drug → Pathway
|
||||
|
||||
Args:
|
||||
df: Transformed DataFrame from fetch_and_transform_data()
|
||||
indication_df: DataFrame with UPID → Indication_Group mapping
|
||||
Must have columns: UPID, Indication_Group
|
||||
Indication_Group is either Search_Term or "Directory (no GP dx)"
|
||||
config: DateFilterConfig specifying the date filter combination
|
||||
trust_filter: List of trust names to include
|
||||
drug_filter: List of drug names to include
|
||||
directory_filter: List of directories to include
|
||||
minimum_patients: Minimum patients to include a pathway
|
||||
max_date: Reference date for computing date ranges
|
||||
paths: PathConfig for file paths
|
||||
|
||||
Returns:
|
||||
DataFrame with pathway hierarchy (ice_df) or None if no data
|
||||
"""
|
||||
if paths is None:
|
||||
paths = default_paths
|
||||
|
||||
# Compute actual date ranges for this filter config
|
||||
start_date, end_date, last_seen_date = compute_date_ranges(config, max_date)
|
||||
|
||||
logger.info(f"Processing indication pathway for {config.id}")
|
||||
logger.info(f" Date range: {start_date} to {end_date}")
|
||||
logger.info(f" Last seen after: {last_seen_date}")
|
||||
|
||||
# Use the indication-aware pathway analyzer
|
||||
ice_df, title = generate_icicle_chart_indication(
|
||||
df=df,
|
||||
indication_df=indication_df,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
last_seen_date=last_seen_date,
|
||||
trust_filter=trust_filter,
|
||||
drug_filter=drug_filter,
|
||||
directory_filter=directory_filter,
|
||||
minimum_num_patients=minimum_patients,
|
||||
title="",
|
||||
paths=paths,
|
||||
)
|
||||
|
||||
if ice_df is None or len(ice_df) == 0:
|
||||
logger.warning(f"No indication pathway data for filter {config.id}")
|
||||
return None
|
||||
|
||||
logger.info(f"Generated {len(ice_df)} indication pathway nodes for {config.id}")
|
||||
return ice_df
|
||||
|
||||
|
||||
def extract_denormalized_fields(ice_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Extract denormalized filter columns from the ids column.
|
||||
|
||||
The ids column contains hierarchical paths like:
|
||||
- "N&WICS" (root)
|
||||
- "N&WICS - NNUH" (trust level)
|
||||
- "N&WICS - NNUH - OPHTHALMOLOGY" (directory level)
|
||||
- "N&WICS - NNUH - OPHTHALMOLOGY - RANIBIZUMAB" (first drug)
|
||||
- "N&WICS - NNUH - OPHTHALMOLOGY - RANIBIZUMAB - AFLIBERCEPT" (pathway)
|
||||
|
||||
This function extracts:
|
||||
- trust_name: The trust component (level 1)
|
||||
- directory: The directory component (level 2)
|
||||
- drug_sequence: Pipe-separated drugs (level 3+)
|
||||
|
||||
Args:
|
||||
ice_df: DataFrame from generate_icicle_chart()
|
||||
|
||||
Returns:
|
||||
DataFrame with added columns: trust_name, directory, drug_sequence
|
||||
"""
|
||||
df = ice_df.copy()
|
||||
|
||||
# Split ids by " - " delimiter
|
||||
def extract_components(ids_str: str) -> tuple[str, str, str]:
|
||||
"""Extract trust, directory, and drug sequence from ids string."""
|
||||
if not ids_str or pd.isna(ids_str):
|
||||
return ("", "", "")
|
||||
|
||||
parts = ids_str.split(" - ")
|
||||
|
||||
# Level 0: Root (e.g., "N&WICS")
|
||||
if len(parts) <= 1:
|
||||
return ("", "", "")
|
||||
|
||||
# Level 1+: Trust is always parts[1]
|
||||
trust_name = parts[1] if len(parts) > 1 else ""
|
||||
|
||||
# Level 2+: Directory is parts[2]
|
||||
directory = parts[2] if len(parts) > 2 else ""
|
||||
|
||||
# Level 3+: Drugs are parts[3:]
|
||||
drugs = parts[3:] if len(parts) > 3 else []
|
||||
drug_sequence = "|".join(drugs) if drugs else ""
|
||||
|
||||
return (trust_name, directory, drug_sequence)
|
||||
|
||||
# Apply extraction to all rows
|
||||
extracted = df['ids'].apply(extract_components)
|
||||
df['trust_name'] = extracted.apply(lambda x: x[0])
|
||||
df['directory'] = extracted.apply(lambda x: x[1])
|
||||
df['drug_sequence'] = extracted.apply(lambda x: x[2])
|
||||
|
||||
logger.info(f"Extracted denormalized fields for {len(df)} nodes")
|
||||
logger.info(f" Unique trusts: {df['trust_name'].nunique()}")
|
||||
logger.info(f" Unique directories: {df['directory'].nunique()}")
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def extract_indication_fields(ice_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Extract denormalized filter columns from the ids column for indication charts.
|
||||
|
||||
Similar to extract_denormalized_fields() but for indication-based charts where
|
||||
the level-2 grouping is Search_Term (or fallback directorate) instead of Directory.
|
||||
|
||||
The ids column contains hierarchical paths like:
|
||||
- "N&WICS" (root)
|
||||
- "N&WICS - NNUH" (trust level)
|
||||
- "N&WICS - NNUH - rheumatoid arthritis" (search_term level - matched patient)
|
||||
- "N&WICS - NNUH - RHEUMATOLOGY (no GP dx)" (fallback level - unmatched patient)
|
||||
- "N&WICS - NNUH - rheumatoid arthritis - ADALIMUMAB" (first drug)
|
||||
- "N&WICS - NNUH - rheumatoid arthritis - ADALIMUMAB - ETANERCEPT" (pathway)
|
||||
|
||||
This function extracts:
|
||||
- trust_name: The trust component (level 1)
|
||||
- search_term: The Search_Term or fallback directorate (level 2)
|
||||
- drug_sequence: Pipe-separated drugs (level 3+)
|
||||
|
||||
Note: For indication charts, 'directory' column contains the search_term
|
||||
to maintain schema compatibility with the pathway_nodes table.
|
||||
|
||||
Args:
|
||||
ice_df: DataFrame from generate_icicle_chart() with indication grouping
|
||||
|
||||
Returns:
|
||||
DataFrame with added columns: trust_name, directory (=search_term), drug_sequence
|
||||
"""
|
||||
df = ice_df.copy()
|
||||
|
||||
def extract_components(ids_str: str) -> tuple[str, str, str]:
|
||||
"""Extract trust, search_term, and drug sequence from ids string."""
|
||||
if not ids_str or pd.isna(ids_str):
|
||||
return ("", "", "")
|
||||
|
||||
parts = ids_str.split(" - ")
|
||||
|
||||
# Level 0: Root (e.g., "N&WICS")
|
||||
if len(parts) <= 1:
|
||||
return ("", "", "")
|
||||
|
||||
# Level 1+: Trust is always parts[1]
|
||||
trust_name = parts[1] if len(parts) > 1 else ""
|
||||
|
||||
# Level 2+: Search_term (or fallback) is parts[2]
|
||||
search_term = parts[2] if len(parts) > 2 else ""
|
||||
|
||||
# Level 3+: Drugs are parts[3:]
|
||||
drugs = parts[3:] if len(parts) > 3 else []
|
||||
drug_sequence = "|".join(drugs) if drugs else ""
|
||||
|
||||
return (trust_name, search_term, drug_sequence)
|
||||
|
||||
# Apply extraction to all rows
|
||||
extracted = df['ids'].apply(extract_components)
|
||||
df['trust_name'] = extracted.apply(lambda x: x[0])
|
||||
# Use 'directory' column to store search_term for schema compatibility
|
||||
df['directory'] = extracted.apply(lambda x: x[1])
|
||||
df['drug_sequence'] = extracted.apply(lambda x: x[2])
|
||||
|
||||
logger.info(f"Extracted indication fields for {len(df)} nodes")
|
||||
logger.info(f" Unique trusts: {df['trust_name'].nunique()}")
|
||||
logger.info(f" Unique search_terms: {df['directory'].nunique()}")
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def convert_to_records(
|
||||
ice_df: pd.DataFrame,
|
||||
date_filter_id: str,
|
||||
refresh_id: Optional[str] = None,
|
||||
chart_type: ChartType = "directory",
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Convert ice_df to a list of dictionaries for SQLite insertion.
|
||||
|
||||
Maps ice_df columns to pathway_nodes table schema:
|
||||
- parents, ids, labels: Direct mapping
|
||||
- level: From ice_df['level']
|
||||
- value, cost, costpp, colour: Direct mapping
|
||||
- cost_pp_pa: From ice_df['cost_pp_pa']
|
||||
- first_seen, last_seen, first_seen_parent, last_seen_parent: Date columns
|
||||
- average_spacing: From ice_df['average_spacing']
|
||||
- average_administered: JSON serialization of list
|
||||
- avg_days: From ice_df['avg_days']
|
||||
- trust_name, directory, drug_sequence: Denormalized fields
|
||||
- date_filter_id: The filter combination ID
|
||||
- chart_type: "directory" or "indication"
|
||||
- data_refresh_id: Optional refresh tracking ID
|
||||
|
||||
Args:
|
||||
ice_df: DataFrame from generate_icicle_chart() with denormalized fields
|
||||
date_filter_id: The date filter combination ID (e.g., 'all_6mo')
|
||||
refresh_id: Optional refresh tracking ID
|
||||
chart_type: Chart type - "directory" (default) or "indication"
|
||||
|
||||
Returns:
|
||||
List of dictionaries ready for SQLite insertion
|
||||
"""
|
||||
records = []
|
||||
|
||||
for _, row in ice_df.iterrows():
|
||||
# Handle date formatting
|
||||
first_seen = None
|
||||
last_seen = None
|
||||
first_seen_parent = None
|
||||
last_seen_parent = None
|
||||
|
||||
if pd.notna(row.get('First seen')):
|
||||
if hasattr(row['First seen'], 'isoformat'):
|
||||
first_seen = row['First seen'].isoformat()
|
||||
else:
|
||||
first_seen = str(row['First seen'])
|
||||
|
||||
if pd.notna(row.get('Last seen')):
|
||||
if hasattr(row['Last seen'], 'isoformat'):
|
||||
last_seen = row['Last seen'].isoformat()
|
||||
else:
|
||||
last_seen = str(row['Last seen'])
|
||||
|
||||
if pd.notna(row.get('First seen (Parent)')):
|
||||
first_seen_parent = str(row['First seen (Parent)'])
|
||||
|
||||
if pd.notna(row.get('Last seen (Parent)')):
|
||||
last_seen_parent = str(row['Last seen (Parent)'])
|
||||
|
||||
# Handle average_administered (could be list, ndarray, or None)
|
||||
average_administered = None
|
||||
val = row.get('average_administered')
|
||||
if val is not None:
|
||||
# Check for scalar None-like values
|
||||
try:
|
||||
if pd.isna(val):
|
||||
average_administered = None
|
||||
elif isinstance(val, (list, np.ndarray)):
|
||||
average_administered = json.dumps(list(val) if hasattr(val, 'tolist') else val)
|
||||
else:
|
||||
average_administered = str(val)
|
||||
except (ValueError, TypeError):
|
||||
# pd.isna raises ValueError for arrays with >1 element
|
||||
# In that case, val is an array/list, so convert to JSON
|
||||
if hasattr(val, 'tolist'):
|
||||
average_administered = json.dumps(val.tolist())
|
||||
elif isinstance(val, list):
|
||||
average_administered = json.dumps(val)
|
||||
else:
|
||||
average_administered = str(val)
|
||||
|
||||
record = {
|
||||
'date_filter_id': date_filter_id,
|
||||
'chart_type': chart_type,
|
||||
'parents': str(row.get('parents', '')) if pd.notna(row.get('parents')) else '',
|
||||
'ids': str(row.get('ids', '')) if pd.notna(row.get('ids')) else '',
|
||||
'labels': str(row.get('labels', '')) if pd.notna(row.get('labels')) else '',
|
||||
'level': int(row.get('level', 0)) if pd.notna(row.get('level')) else 0,
|
||||
'value': int(row.get('value', 0)) if pd.notna(row.get('value')) else 0,
|
||||
'cost': float(row.get('cost', 0)) if pd.notna(row.get('cost')) else 0.0,
|
||||
'costpp': float(row.get('costpp')) if pd.notna(row.get('costpp')) else None,
|
||||
'cost_pp_pa': str(row.get('cost_pp_pa', '')) if pd.notna(row.get('cost_pp_pa')) else None,
|
||||
'colour': float(row.get('colour', 0)) if pd.notna(row.get('colour')) else 0.0,
|
||||
'first_seen': first_seen,
|
||||
'last_seen': last_seen,
|
||||
'first_seen_parent': first_seen_parent,
|
||||
'last_seen_parent': last_seen_parent,
|
||||
'average_spacing': str(row.get('average_spacing', '')) if pd.notna(row.get('average_spacing')) else None,
|
||||
'average_administered': average_administered,
|
||||
'avg_days': float(row['avg_days'].total_seconds() / 86400) if pd.notna(row.get('avg_days')) and hasattr(row.get('avg_days'), 'total_seconds') else (float(row.get('avg_days')) if pd.notna(row.get('avg_days')) else None),
|
||||
'trust_name': row.get('trust_name', '') if pd.notna(row.get('trust_name')) else None,
|
||||
'directory': row.get('directory', '') if pd.notna(row.get('directory')) else None,
|
||||
'drug_sequence': row.get('drug_sequence', '') if pd.notna(row.get('drug_sequence')) else None,
|
||||
'data_refresh_id': refresh_id,
|
||||
}
|
||||
records.append(record)
|
||||
|
||||
logger.info(f"Converted {len(records)} pathway nodes to records for {date_filter_id} ({chart_type})")
|
||||
return records
|
||||
|
||||
|
||||
def process_all_date_filters(
|
||||
df: pd.DataFrame,
|
||||
trust_filter: list[str],
|
||||
drug_filter: list[str],
|
||||
directory_filter: list[str],
|
||||
minimum_patients: int = 5,
|
||||
max_date: Optional[date] = None,
|
||||
refresh_id: Optional[str] = None,
|
||||
paths: Optional[PathConfig] = None,
|
||||
) -> dict[str, list[dict]]:
|
||||
"""
|
||||
Process pathway data for all 6 date filter combinations.
|
||||
|
||||
This is a convenience function that processes all DATE_FILTER_CONFIGS
|
||||
and returns a dictionary of records ready for SQLite insertion.
|
||||
|
||||
Args:
|
||||
df: Transformed DataFrame from fetch_and_transform_data()
|
||||
trust_filter: List of trust names to include
|
||||
drug_filter: List of drug names to include
|
||||
directory_filter: List of directories to include
|
||||
minimum_patients: Minimum patients to include a pathway
|
||||
max_date: Reference date for computing date ranges
|
||||
refresh_id: Optional refresh tracking ID
|
||||
paths: PathConfig for file paths
|
||||
|
||||
Returns:
|
||||
Dictionary mapping date_filter_id to list of record dicts
|
||||
e.g., {"all_6mo": [...], "all_12mo": [...], ...}
|
||||
"""
|
||||
if paths is None:
|
||||
paths = default_paths
|
||||
|
||||
results = {}
|
||||
|
||||
for config in DATE_FILTER_CONFIGS:
|
||||
logger.info(f"Processing date filter: {config.id}")
|
||||
|
||||
# Process pathway for this date filter
|
||||
ice_df = process_pathway_for_date_filter(
|
||||
df=df,
|
||||
config=config,
|
||||
trust_filter=trust_filter,
|
||||
drug_filter=drug_filter,
|
||||
directory_filter=directory_filter,
|
||||
minimum_patients=minimum_patients,
|
||||
max_date=max_date,
|
||||
paths=paths,
|
||||
)
|
||||
|
||||
if ice_df is None:
|
||||
logger.warning(f"Skipping {config.id} - no data")
|
||||
results[config.id] = []
|
||||
continue
|
||||
|
||||
# Extract denormalized fields
|
||||
ice_df = extract_denormalized_fields(ice_df)
|
||||
|
||||
# Convert to records
|
||||
records = convert_to_records(ice_df, config.id, refresh_id)
|
||||
results[config.id] = records
|
||||
|
||||
logger.info(f"Completed {config.id}: {len(records)} nodes")
|
||||
|
||||
total_records = sum(len(r) for r in results.values())
|
||||
logger.info(f"Total pathway nodes across all filters: {total_records}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# Export public API
|
||||
__all__ = [
|
||||
# Types
|
||||
"ChartType",
|
||||
# Data classes
|
||||
"DateFilterConfig",
|
||||
"DATE_FILTER_CONFIGS",
|
||||
# Core functions
|
||||
"compute_date_ranges",
|
||||
"fetch_and_transform_data",
|
||||
# Directory chart processing
|
||||
"process_pathway_for_date_filter",
|
||||
"extract_denormalized_fields",
|
||||
# Indication chart processing
|
||||
"process_indication_pathway_for_date_filter",
|
||||
"extract_indication_fields",
|
||||
# Common utilities
|
||||
"convert_to_records",
|
||||
"process_all_date_filters",
|
||||
]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,709 @@
|
||||
"""
|
||||
SQLite schema definitions for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
Contains SQL strings for creating reference tables, fact tables, and indexes.
|
||||
Schema design supports:
|
||||
- Reference data from CSV files (drug names, organizations, directories)
|
||||
- Drug-directory mappings with single-valid-directory flag
|
||||
- Patient intervention facts with proper indexing
|
||||
- Cached aggregations for performance
|
||||
- File tracking for incremental updates
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
import sqlite3
|
||||
|
||||
from core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Reference Table Schemas
|
||||
# =============================================================================
|
||||
|
||||
REF_DRUG_NAMES_SCHEMA = """
|
||||
-- Mapping from raw drug names (as they appear in source data) to standardized names
|
||||
-- Source: data/drugnames.csv
|
||||
CREATE TABLE IF NOT EXISTS ref_drug_names (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
raw_name TEXT NOT NULL UNIQUE,
|
||||
standard_name TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- Index for fast lookups during data transformation
|
||||
CREATE INDEX IF NOT EXISTS idx_ref_drug_names_raw ON ref_drug_names(raw_name);
|
||||
CREATE INDEX IF NOT EXISTS idx_ref_drug_names_standard ON ref_drug_names(standard_name);
|
||||
"""
|
||||
|
||||
REF_ORGANIZATIONS_SCHEMA = """
|
||||
-- NHS organization codes and names
|
||||
-- Source: data/org_codes.csv
|
||||
CREATE TABLE IF NOT EXISTS ref_organizations (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
org_code TEXT NOT NULL UNIQUE,
|
||||
org_name TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- Index for fast lookups by organization code
|
||||
CREATE INDEX IF NOT EXISTS idx_ref_organizations_code ON ref_organizations(org_code);
|
||||
"""
|
||||
|
||||
REF_DIRECTORIES_SCHEMA = """
|
||||
-- Medical directories/specialties
|
||||
-- Source: data/directory_list.csv
|
||||
CREATE TABLE IF NOT EXISTS ref_directories (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
directory_name TEXT NOT NULL UNIQUE,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- Index for fast lookups by directory name
|
||||
CREATE INDEX IF NOT EXISTS idx_ref_directories_name ON ref_directories(directory_name);
|
||||
"""
|
||||
|
||||
REF_DRUG_DIRECTORY_MAP_SCHEMA = """
|
||||
-- Mapping from drug names to valid directories
|
||||
-- Source: data/drug_directory_list.csv
|
||||
-- A drug may map to multiple directories (one row per drug-directory pair)
|
||||
-- The is_single_valid flag indicates drugs with exactly ONE valid directory,
|
||||
-- which enables automatic directory assignment in department_identification()
|
||||
CREATE TABLE IF NOT EXISTS ref_drug_directory_map (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
drug_name TEXT NOT NULL,
|
||||
directory_name TEXT NOT NULL,
|
||||
is_single_valid BOOLEAN NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(drug_name, directory_name)
|
||||
);
|
||||
|
||||
-- Index for looking up directories by drug name (most common access pattern)
|
||||
CREATE INDEX IF NOT EXISTS idx_ref_drug_directory_map_drug ON ref_drug_directory_map(drug_name);
|
||||
|
||||
-- Index for reverse lookup (find drugs by directory)
|
||||
CREATE INDEX IF NOT EXISTS idx_ref_drug_directory_map_directory ON ref_drug_directory_map(directory_name);
|
||||
|
||||
-- Index for quick filtering of single-valid drugs
|
||||
CREATE INDEX IF NOT EXISTS idx_ref_drug_directory_map_single ON ref_drug_directory_map(is_single_valid);
|
||||
"""
|
||||
|
||||
REF_DRUG_INDICATION_CLUSTERS_SCHEMA = """
|
||||
-- Mapping from drugs to SNOMED clusters for indication validation
|
||||
-- Source: data/drug_indication_clusters.csv
|
||||
-- Used to validate that patients have appropriate GP diagnoses for their prescribed drugs
|
||||
-- A drug may map to multiple clusters (one row per drug-indication-cluster combination)
|
||||
CREATE TABLE IF NOT EXISTS ref_drug_indication_clusters (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
drug_name TEXT NOT NULL,
|
||||
indication TEXT NOT NULL,
|
||||
cluster_id TEXT NOT NULL,
|
||||
cluster_description TEXT,
|
||||
nice_ta_reference TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(drug_name, indication, cluster_id)
|
||||
);
|
||||
|
||||
-- Index for looking up clusters by drug name (most common access pattern)
|
||||
CREATE INDEX IF NOT EXISTS idx_ref_drug_indication_clusters_drug ON ref_drug_indication_clusters(drug_name);
|
||||
|
||||
-- Index for looking up drugs by cluster (for finding all drugs treating a condition)
|
||||
CREATE INDEX IF NOT EXISTS idx_ref_drug_indication_clusters_cluster ON ref_drug_indication_clusters(cluster_id);
|
||||
|
||||
-- Index for looking up by indication text
|
||||
CREATE INDEX IF NOT EXISTS idx_ref_drug_indication_clusters_indication ON ref_drug_indication_clusters(indication);
|
||||
"""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pathway Data Architecture Schemas
|
||||
# =============================================================================
|
||||
|
||||
PATHWAY_DATE_FILTERS_SCHEMA = """
|
||||
-- Stores the 6 pre-computed date filter combinations
|
||||
-- Each combination represents a different initiated/last_seen date range
|
||||
-- Used to efficiently query pre-computed pathway data
|
||||
CREATE TABLE IF NOT EXISTS pathway_date_filters (
|
||||
id TEXT PRIMARY KEY, -- e.g., 'all_6mo', '1yr_12mo'
|
||||
initiated_label TEXT NOT NULL, -- e.g., 'All years', 'Last 1 year', 'Last 2 years'
|
||||
last_seen_label TEXT NOT NULL, -- e.g., 'Last 6 months', 'Last 12 months'
|
||||
initiated_years INTEGER, -- NULL for 'All', 1, or 2
|
||||
last_seen_months INTEGER NOT NULL, -- 6 or 12
|
||||
is_default INTEGER DEFAULT 0, -- 1 for 'all_6mo' (default selection)
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- Pre-populate the 6 combinations
|
||||
INSERT OR REPLACE INTO pathway_date_filters (id, initiated_label, last_seen_label, initiated_years, last_seen_months, is_default) VALUES
|
||||
('all_6mo', 'All years', 'Last 6 months', NULL, 6, 1),
|
||||
('all_12mo', 'All years', 'Last 12 months', NULL, 12, 0),
|
||||
('1yr_6mo', 'Last 1 year', 'Last 6 months', 1, 6, 0),
|
||||
('1yr_12mo', 'Last 1 year', 'Last 12 months', 1, 12, 0),
|
||||
('2yr_6mo', 'Last 2 years', 'Last 6 months', 2, 6, 0),
|
||||
('2yr_12mo', 'Last 2 years', 'Last 12 months', 2, 12, 0);
|
||||
"""
|
||||
|
||||
PATHWAY_NODES_SCHEMA = """
|
||||
-- Main pathway nodes table (one set per date filter + chart type combination)
|
||||
-- Stores pre-computed pathway hierarchy with all visualization data
|
||||
-- Designed for fast filtering by date_filter_id + chart_type + trust/directory/drug
|
||||
CREATE TABLE IF NOT EXISTS pathway_nodes (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
|
||||
-- Date filter combination this belongs to
|
||||
date_filter_id TEXT NOT NULL,
|
||||
|
||||
-- Chart type: "directory" (Trust→Directory→Drug) or "indication" (Trust→SearchTerm→Drug)
|
||||
chart_type TEXT NOT NULL DEFAULT 'directory',
|
||||
|
||||
-- Hierarchy structure (for icicle chart)
|
||||
parents TEXT NOT NULL, -- Parent node identifier
|
||||
ids TEXT NOT NULL, -- Unique node identifier (hierarchical path)
|
||||
labels TEXT NOT NULL, -- Display label
|
||||
level INTEGER NOT NULL, -- Hierarchy depth (0=root, 1=trust, 2=directory/search_term, 3+=drugs)
|
||||
|
||||
-- Patient counts (accurate for this date filter combination)
|
||||
value INTEGER NOT NULL DEFAULT 0, -- Patient count
|
||||
|
||||
-- Cost metrics
|
||||
cost REAL NOT NULL DEFAULT 0.0, -- Total cost
|
||||
costpp REAL, -- Cost per patient
|
||||
cost_pp_pa TEXT, -- Cost per patient per annum (formatted string)
|
||||
|
||||
-- Visualization
|
||||
colour REAL NOT NULL DEFAULT 0.0, -- Color value (proportion of parent)
|
||||
|
||||
-- Date ranges (for this node)
|
||||
first_seen TEXT, -- First intervention date (ISO format)
|
||||
last_seen TEXT, -- Last intervention date (ISO format)
|
||||
first_seen_parent TEXT, -- Earliest date in parent group
|
||||
last_seen_parent TEXT, -- Latest date in parent group
|
||||
|
||||
-- Treatment statistics
|
||||
average_spacing TEXT, -- Formatted treatment duration string
|
||||
average_administered TEXT, -- JSON array of average doses per drug
|
||||
avg_days REAL, -- Average treatment duration in days
|
||||
|
||||
-- Denormalized filter columns (for efficient WHERE clause filtering)
|
||||
trust_name TEXT, -- Extracted trust name from ids
|
||||
directory TEXT, -- Extracted directory from ids
|
||||
drug_sequence TEXT, -- Pipe-separated drug sequence from pathway
|
||||
|
||||
-- Metadata
|
||||
created_at TEXT DEFAULT CURRENT_TIMESTAMP,
|
||||
data_refresh_id TEXT, -- Links to pathway_refresh_log
|
||||
|
||||
-- Unique per date filter + chart type + pathway
|
||||
UNIQUE(date_filter_id, chart_type, ids),
|
||||
FOREIGN KEY (date_filter_id) REFERENCES pathway_date_filters(id)
|
||||
);
|
||||
|
||||
-- Indexes for efficient filtering
|
||||
-- Primary filter: select by date_filter_id
|
||||
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_date_filter ON pathway_nodes(date_filter_id);
|
||||
|
||||
-- Chart type filter: for switching between directory and indication views
|
||||
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_chart_type ON pathway_nodes(date_filter_id, chart_type);
|
||||
|
||||
-- Level filter: often used with date_filter_id
|
||||
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_level ON pathway_nodes(date_filter_id, level);
|
||||
|
||||
-- Trust filter: for Trust dropdown filtering
|
||||
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_trust ON pathway_nodes(date_filter_id, trust_name);
|
||||
|
||||
-- Directory filter: for Directory dropdown filtering
|
||||
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_directory ON pathway_nodes(date_filter_id, directory);
|
||||
|
||||
-- Drug sequence filter: for drug filtering (uses LIKE '%DRUG%')
|
||||
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_drug_seq ON pathway_nodes(drug_sequence);
|
||||
|
||||
-- Parents filter: for finding children of a node
|
||||
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_parents ON pathway_nodes(date_filter_id, parents);
|
||||
|
||||
-- Composite index for common filter combination
|
||||
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_filter_composite
|
||||
ON pathway_nodes(date_filter_id, chart_type, trust_name, directory);
|
||||
"""
|
||||
|
||||
PATHWAY_REFRESH_LOG_SCHEMA = """
|
||||
-- Metadata table for tracking refresh status
|
||||
-- Tracks when pathway data was last refreshed from Snowflake
|
||||
CREATE TABLE IF NOT EXISTS pathway_refresh_log (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
refresh_id TEXT NOT NULL, -- Unique identifier for this refresh run
|
||||
started_at TEXT NOT NULL, -- ISO timestamp when refresh started
|
||||
completed_at TEXT, -- ISO timestamp when refresh completed (NULL if still running)
|
||||
status TEXT DEFAULT 'running', -- 'running', 'completed', 'failed'
|
||||
record_count INTEGER, -- Total pathway_nodes records created
|
||||
date_filter_counts TEXT, -- JSON: {"all_6mo": 1234, "all_12mo": 1567, ...}
|
||||
error_message TEXT, -- Error details if status='failed'
|
||||
snowflake_query_date_from TEXT, -- Start date of Snowflake query
|
||||
snowflake_query_date_to TEXT, -- End date of Snowflake query
|
||||
processing_duration_seconds REAL, -- How long the refresh took
|
||||
source_row_count INTEGER, -- Number of Snowflake rows fetched
|
||||
created_at TEXT DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- Index for finding latest refresh
|
||||
CREATE INDEX IF NOT EXISTS idx_pathway_refresh_log_started ON pathway_refresh_log(started_at DESC);
|
||||
|
||||
-- Index for finding by status
|
||||
CREATE INDEX IF NOT EXISTS idx_pathway_refresh_log_status ON pathway_refresh_log(status);
|
||||
"""
|
||||
|
||||
# Combined pathway schema
|
||||
PATHWAY_TABLES_SCHEMA = f"""
|
||||
-- Pathway Data Architecture Tables
|
||||
-- Pre-computed pathway data for fast Reflex filtering
|
||||
|
||||
{PATHWAY_DATE_FILTERS_SCHEMA}
|
||||
|
||||
{PATHWAY_NODES_SCHEMA}
|
||||
|
||||
{PATHWAY_REFRESH_LOG_SCHEMA}
|
||||
"""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Combined Schemas
|
||||
# =============================================================================
|
||||
|
||||
REFERENCE_TABLES_SCHEMA = f"""
|
||||
-- Reference Tables Schema
|
||||
-- Contains lookup data migrated from CSV files
|
||||
|
||||
{REF_DRUG_NAMES_SCHEMA}
|
||||
|
||||
{REF_ORGANIZATIONS_SCHEMA}
|
||||
|
||||
{REF_DIRECTORIES_SCHEMA}
|
||||
|
||||
{REF_DRUG_DIRECTORY_MAP_SCHEMA}
|
||||
|
||||
{REF_DRUG_INDICATION_CLUSTERS_SCHEMA}
|
||||
"""
|
||||
|
||||
ALL_TABLES_SCHEMA = f"""
|
||||
-- Complete Database Schema
|
||||
-- Reference tables + Pathway tables
|
||||
|
||||
{REFERENCE_TABLES_SCHEMA}
|
||||
|
||||
{PATHWAY_TABLES_SCHEMA}
|
||||
"""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Schema Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
def create_reference_tables(conn: sqlite3.Connection) -> None:
|
||||
"""
|
||||
Create all reference tables in the database.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
"""
|
||||
logger.info("Creating reference tables...")
|
||||
conn.executescript(REFERENCE_TABLES_SCHEMA)
|
||||
logger.info("Reference tables created successfully")
|
||||
|
||||
|
||||
def drop_reference_tables(conn: sqlite3.Connection) -> None:
|
||||
"""
|
||||
Drop all reference tables from the database.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Warning:
|
||||
This will delete all reference data. Use with caution.
|
||||
"""
|
||||
logger.warning("Dropping reference tables...")
|
||||
conn.executescript("""
|
||||
DROP TABLE IF EXISTS ref_drug_names;
|
||||
DROP TABLE IF EXISTS ref_organizations;
|
||||
DROP TABLE IF EXISTS ref_directories;
|
||||
DROP TABLE IF EXISTS ref_drug_directory_map;
|
||||
DROP TABLE IF EXISTS ref_drug_indication_clusters;
|
||||
""")
|
||||
logger.info("Reference tables dropped")
|
||||
|
||||
|
||||
def get_reference_table_counts(conn: sqlite3.Connection) -> dict[str, int]:
|
||||
"""
|
||||
Get row counts for all reference tables.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping table name to row count.
|
||||
"""
|
||||
tables = [
|
||||
"ref_drug_names",
|
||||
"ref_organizations",
|
||||
"ref_directories",
|
||||
"ref_drug_directory_map",
|
||||
"ref_drug_indication_clusters",
|
||||
]
|
||||
counts = {}
|
||||
|
||||
for table in tables:
|
||||
try:
|
||||
cursor = conn.execute(f"SELECT COUNT(*) FROM {table}")
|
||||
result = cursor.fetchone()
|
||||
counts[table] = result[0] if result else 0
|
||||
except sqlite3.OperationalError:
|
||||
counts[table] = 0
|
||||
|
||||
return counts
|
||||
|
||||
|
||||
def verify_reference_tables_exist(conn: sqlite3.Connection) -> list[str]:
|
||||
"""
|
||||
Verify that all reference tables exist.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Returns:
|
||||
List of missing table names. Empty list means all tables exist.
|
||||
"""
|
||||
required_tables = [
|
||||
"ref_drug_names",
|
||||
"ref_organizations",
|
||||
"ref_directories",
|
||||
"ref_drug_directory_map",
|
||||
"ref_drug_indication_clusters",
|
||||
]
|
||||
missing = []
|
||||
|
||||
for table in required_tables:
|
||||
cursor = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name=?",
|
||||
(table,)
|
||||
)
|
||||
if cursor.fetchone() is None:
|
||||
missing.append(table)
|
||||
|
||||
return missing
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pathway Table Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
def create_pathway_tables(conn: sqlite3.Connection) -> None:
|
||||
"""
|
||||
Create pathway data architecture tables in the database.
|
||||
|
||||
Creates:
|
||||
- pathway_date_filters: 6 pre-defined date filter combinations
|
||||
- pathway_nodes: Pre-computed pathway hierarchy data
|
||||
- pathway_refresh_log: Refresh tracking metadata
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
"""
|
||||
logger.info("Creating pathway tables...")
|
||||
conn.executescript(PATHWAY_TABLES_SCHEMA)
|
||||
logger.info("Pathway tables created successfully")
|
||||
|
||||
|
||||
def drop_pathway_tables(conn: sqlite3.Connection) -> None:
|
||||
"""
|
||||
Drop pathway data architecture tables from the database.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Warning:
|
||||
This will delete all pre-computed pathway data.
|
||||
"""
|
||||
logger.warning("Dropping pathway tables...")
|
||||
conn.executescript("""
|
||||
DROP TABLE IF EXISTS pathway_nodes;
|
||||
DROP TABLE IF EXISTS pathway_refresh_log;
|
||||
DROP TABLE IF EXISTS pathway_date_filters;
|
||||
""")
|
||||
logger.info("Pathway tables dropped")
|
||||
|
||||
|
||||
def get_pathway_table_counts(conn: sqlite3.Connection) -> dict[str, int]:
|
||||
"""
|
||||
Get row counts for pathway tables.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping table name to row count.
|
||||
"""
|
||||
tables = ["pathway_date_filters", "pathway_nodes", "pathway_refresh_log"]
|
||||
counts = {}
|
||||
|
||||
for table in tables:
|
||||
try:
|
||||
cursor = conn.execute(f"SELECT COUNT(*) FROM {table}")
|
||||
result = cursor.fetchone()
|
||||
counts[table] = result[0] if result else 0
|
||||
except sqlite3.OperationalError:
|
||||
# Table doesn't exist yet
|
||||
counts[table] = 0
|
||||
|
||||
return counts
|
||||
|
||||
|
||||
def verify_pathway_tables_exist(conn: sqlite3.Connection) -> list[str]:
|
||||
"""
|
||||
Verify that pathway tables exist.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Returns:
|
||||
List of missing table names. Empty list means all tables exist.
|
||||
"""
|
||||
required_tables = ["pathway_date_filters", "pathway_nodes", "pathway_refresh_log"]
|
||||
missing = []
|
||||
|
||||
for table in required_tables:
|
||||
cursor = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name=?",
|
||||
(table,)
|
||||
)
|
||||
if cursor.fetchone() is None:
|
||||
missing.append(table)
|
||||
|
||||
return missing
|
||||
|
||||
|
||||
def clear_pathway_nodes(conn: sqlite3.Connection, date_filter_id: str | None = None) -> int:
|
||||
"""
|
||||
Clear pathway nodes, optionally for a specific date filter.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
date_filter_id: If provided, only clear nodes for this date filter.
|
||||
If None, clear all pathway nodes.
|
||||
|
||||
Returns:
|
||||
Number of rows deleted.
|
||||
"""
|
||||
if date_filter_id:
|
||||
cursor = conn.execute(
|
||||
"DELETE FROM pathway_nodes WHERE date_filter_id = ?",
|
||||
(date_filter_id,)
|
||||
)
|
||||
else:
|
||||
cursor = conn.execute("DELETE FROM pathway_nodes")
|
||||
|
||||
deleted_count = cursor.rowcount
|
||||
conn.commit()
|
||||
logger.info(f"Cleared {deleted_count} pathway nodes")
|
||||
return deleted_count
|
||||
|
||||
|
||||
def get_pathway_refresh_status(conn: sqlite3.Connection) -> dict | None:
|
||||
"""
|
||||
Get the status of the most recent pathway refresh.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Returns:
|
||||
Dictionary with refresh status, or None if no refresh has been done.
|
||||
"""
|
||||
try:
|
||||
cursor = conn.execute("""
|
||||
SELECT refresh_id, started_at, completed_at, status, record_count,
|
||||
date_filter_counts, error_message, processing_duration_seconds
|
||||
FROM pathway_refresh_log
|
||||
ORDER BY started_at DESC
|
||||
LIMIT 1
|
||||
""")
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
return {
|
||||
"refresh_id": row[0],
|
||||
"started_at": row[1],
|
||||
"completed_at": row[2],
|
||||
"status": row[3],
|
||||
"record_count": row[4],
|
||||
"date_filter_counts": row[5],
|
||||
"error_message": row[6],
|
||||
"processing_duration_seconds": row[7],
|
||||
}
|
||||
return None
|
||||
except sqlite3.OperationalError:
|
||||
# Table doesn't exist yet
|
||||
return None
|
||||
|
||||
|
||||
def migrate_pathway_nodes_chart_type(conn: sqlite3.Connection) -> tuple[bool, str]:
|
||||
"""
|
||||
Migrate pathway_nodes table to add chart_type column.
|
||||
|
||||
This migration:
|
||||
1. Checks if chart_type column already exists
|
||||
2. If not, adds it with DEFAULT 'directory'
|
||||
3. Updates existing rows to have 'directory' chart_type
|
||||
4. Adds index for efficient filtering
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Returns:
|
||||
Tuple of (success: bool, message: str)
|
||||
"""
|
||||
# Check if table exists
|
||||
cursor = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='pathway_nodes'"
|
||||
)
|
||||
if cursor.fetchone() is None:
|
||||
return True, "pathway_nodes table does not exist yet (will be created with chart_type column)"
|
||||
|
||||
# Check if chart_type column already exists
|
||||
cursor = conn.execute("PRAGMA table_info(pathway_nodes)")
|
||||
columns = [row[1] for row in cursor.fetchall()]
|
||||
|
||||
if "chart_type" in columns:
|
||||
return True, "chart_type column already exists in pathway_nodes"
|
||||
|
||||
# Add chart_type column
|
||||
logger.info("Adding chart_type column to pathway_nodes table...")
|
||||
try:
|
||||
# Add column with default value
|
||||
conn.execute("""
|
||||
ALTER TABLE pathway_nodes
|
||||
ADD COLUMN chart_type TEXT NOT NULL DEFAULT 'directory'
|
||||
""")
|
||||
|
||||
# Create index for efficient filtering by chart type
|
||||
conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_chart_type
|
||||
ON pathway_nodes(date_filter_id, chart_type)
|
||||
""")
|
||||
|
||||
# Update existing composite index (need to drop and recreate)
|
||||
# Note: SQLite doesn't support DROP INDEX IF EXISTS in older versions,
|
||||
# so we use a try/except
|
||||
try:
|
||||
conn.execute("DROP INDEX idx_pathway_nodes_filter_composite")
|
||||
except sqlite3.OperationalError:
|
||||
pass # Index didn't exist
|
||||
|
||||
conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_filter_composite
|
||||
ON pathway_nodes(date_filter_id, chart_type, trust_name, directory)
|
||||
""")
|
||||
|
||||
# Need to recreate unique constraint since it changed
|
||||
# SQLite doesn't support ALTER TABLE to change constraints, but
|
||||
# since we're adding a column with a default value and the old
|
||||
# constraint was (date_filter_id, ids), the new constraint
|
||||
# (date_filter_id, chart_type, ids) will be satisfied by all existing
|
||||
# rows since they all have chart_type='directory'
|
||||
|
||||
conn.commit()
|
||||
logger.info("chart_type column added successfully")
|
||||
|
||||
# Count updated rows
|
||||
cursor = conn.execute("SELECT COUNT(*) FROM pathway_nodes")
|
||||
row_count = cursor.fetchone()[0]
|
||||
|
||||
return True, f"Added chart_type column, {row_count} existing rows set to 'directory'"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add chart_type column: {e}")
|
||||
return False, f"Migration failed: {e}"
|
||||
|
||||
|
||||
def migrate_refresh_log_source_row_count(conn: sqlite3.Connection) -> tuple[bool, str]:
|
||||
"""Add source_row_count column to pathway_refresh_log if it doesn't exist.
|
||||
|
||||
This column stores the Snowflake row count for display in the UI footer.
|
||||
"""
|
||||
cursor = conn.execute("PRAGMA table_info(pathway_refresh_log)")
|
||||
columns = [row[1] for row in cursor.fetchall()]
|
||||
|
||||
if "source_row_count" in columns:
|
||||
return True, "source_row_count column already exists"
|
||||
|
||||
logger.info("Adding source_row_count column to pathway_refresh_log...")
|
||||
try:
|
||||
conn.execute("""
|
||||
ALTER TABLE pathway_refresh_log
|
||||
ADD COLUMN source_row_count INTEGER
|
||||
""")
|
||||
conn.commit()
|
||||
return True, "Added source_row_count column"
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add source_row_count column: {e}")
|
||||
return False, f"Migration failed: {e}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Combined Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
def create_all_tables(conn: sqlite3.Connection) -> None:
|
||||
"""
|
||||
Create all tables (reference + pathway) in the database.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
"""
|
||||
logger.info("Creating all database tables...")
|
||||
conn.executescript(ALL_TABLES_SCHEMA)
|
||||
logger.info("All tables created successfully")
|
||||
|
||||
|
||||
def drop_all_tables(conn: sqlite3.Connection) -> None:
|
||||
"""
|
||||
Drop all tables from the database.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Warning:
|
||||
This will delete all data. Use with extreme caution.
|
||||
"""
|
||||
logger.warning("Dropping all tables...")
|
||||
drop_pathway_tables(conn)
|
||||
drop_reference_tables(conn)
|
||||
logger.info("All tables dropped")
|
||||
|
||||
|
||||
def get_all_table_counts(conn: sqlite3.Connection) -> dict[str, int]:
|
||||
"""
|
||||
Get row counts for all tables.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping table name to row count.
|
||||
"""
|
||||
counts = {}
|
||||
counts.update(get_reference_table_counts(conn))
|
||||
counts.update(get_pathway_table_counts(conn))
|
||||
return counts
|
||||
|
||||
|
||||
def verify_all_tables_exist(conn: sqlite3.Connection) -> list[str]:
|
||||
"""
|
||||
Verify that all tables exist.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Returns:
|
||||
List of missing table names. Empty list means all tables exist.
|
||||
"""
|
||||
missing = []
|
||||
missing.extend(verify_reference_tables_exist(conn))
|
||||
missing.extend(verify_pathway_tables_exist(conn))
|
||||
return missing
|
||||
@@ -0,0 +1,797 @@
|
||||
"""
|
||||
Snowflake connector module for NHS Patient Pathway Analysis.
|
||||
|
||||
Provides connection handling with SSO browser authentication for NHS environments.
|
||||
Uses the externalbrowser authenticator which opens a browser window for NHS identity
|
||||
management authentication.
|
||||
|
||||
Usage:
|
||||
from data_processing.snowflake_connector import SnowflakeConnector, get_connector
|
||||
|
||||
# Using context manager (recommended)
|
||||
with get_connector() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM table LIMIT 10")
|
||||
results = cursor.fetchall()
|
||||
|
||||
# Manual connection management
|
||||
connector = SnowflakeConnector()
|
||||
try:
|
||||
conn = connector.connect()
|
||||
cursor = conn.cursor()
|
||||
# ... use cursor ...
|
||||
finally:
|
||||
connector.close()
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator, Optional, TYPE_CHECKING
|
||||
import time
|
||||
|
||||
# Snowflake connector is an optional dependency
|
||||
SNOWFLAKE_AVAILABLE = False
|
||||
try:
|
||||
import snowflake.connector
|
||||
from snowflake.connector import SnowflakeConnection
|
||||
from snowflake.connector.cursor import SnowflakeCursor
|
||||
SNOWFLAKE_AVAILABLE = True
|
||||
except ImportError:
|
||||
snowflake = None # type: ignore[assignment]
|
||||
|
||||
# Type hints for when snowflake is not available
|
||||
if TYPE_CHECKING:
|
||||
from snowflake.connector import SnowflakeConnection
|
||||
from snowflake.connector.cursor import SnowflakeCursor
|
||||
|
||||
from config import get_snowflake_config, SnowflakeConfig
|
||||
from core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SnowflakeConnectionError(Exception):
|
||||
"""Raised when Snowflake connection fails."""
|
||||
pass
|
||||
|
||||
|
||||
class SnowflakeNotConfiguredError(Exception):
|
||||
"""Raised when Snowflake is not configured (no account)."""
|
||||
pass
|
||||
|
||||
|
||||
class SnowflakeNotAvailableError(Exception):
|
||||
"""Raised when snowflake-connector-python is not installed."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionInfo:
|
||||
"""Information about the current connection state."""
|
||||
connected: bool = False
|
||||
account: str = ""
|
||||
warehouse: str = ""
|
||||
database: str = ""
|
||||
schema: str = ""
|
||||
user: str = ""
|
||||
role: str = ""
|
||||
connected_at: Optional[datetime] = None
|
||||
last_query_at: Optional[datetime] = None
|
||||
query_count: int = 0
|
||||
|
||||
|
||||
class SnowflakeConnector:
|
||||
"""
|
||||
Manages Snowflake connections with SSO browser authentication.
|
||||
|
||||
This class provides connection management for NHS Snowflake access using
|
||||
the externalbrowser authenticator which triggers NHS SSO login via browser.
|
||||
|
||||
Attributes:
|
||||
config: SnowflakeConfig with connection settings
|
||||
connection_info: ConnectionInfo tracking current state
|
||||
|
||||
Example:
|
||||
connector = SnowflakeConnector()
|
||||
with connector.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT CURRENT_USER()")
|
||||
print(cursor.fetchone()[0])
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[SnowflakeConfig] = None):
|
||||
"""
|
||||
Initialize the connector with configuration.
|
||||
|
||||
Args:
|
||||
config: Optional SnowflakeConfig. If not provided, loads from
|
||||
config/snowflake.toml using get_snowflake_config().
|
||||
"""
|
||||
self._config = config or get_snowflake_config()
|
||||
self._connection: Optional[SnowflakeConnection] = None
|
||||
self._connection_info = ConnectionInfo()
|
||||
|
||||
@property
|
||||
def config(self) -> SnowflakeConfig:
|
||||
"""Return the Snowflake configuration."""
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def connection_info(self) -> ConnectionInfo:
|
||||
"""Return information about the current connection state."""
|
||||
return self._connection_info
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Return True if currently connected to Snowflake."""
|
||||
return self._connection is not None and not self._connection.is_closed()
|
||||
|
||||
def _check_availability(self) -> None:
|
||||
"""Check that snowflake-connector-python is installed."""
|
||||
if not SNOWFLAKE_AVAILABLE:
|
||||
raise SnowflakeNotAvailableError(
|
||||
"snowflake-connector-python is not installed. "
|
||||
"Install it with: pip install snowflake-connector-python"
|
||||
)
|
||||
|
||||
def _check_configured(self) -> None:
|
||||
"""Check that Snowflake is configured."""
|
||||
if not self._config.is_configured:
|
||||
raise SnowflakeNotConfiguredError(
|
||||
"Snowflake account is not configured. "
|
||||
"Edit config/snowflake.toml and set connection.account"
|
||||
)
|
||||
|
||||
def connect(self) -> SnowflakeConnection:
|
||||
"""
|
||||
Establish a connection to Snowflake.
|
||||
|
||||
Uses the externalbrowser authenticator which opens a browser window
|
||||
for NHS SSO authentication. The browser popup is expected and normal.
|
||||
|
||||
Returns:
|
||||
Active SnowflakeConnection
|
||||
|
||||
Raises:
|
||||
SnowflakeNotAvailableError: If snowflake-connector-python not installed
|
||||
SnowflakeNotConfiguredError: If account is not configured
|
||||
SnowflakeConnectionError: If connection fails
|
||||
"""
|
||||
self._check_availability()
|
||||
self._check_configured()
|
||||
|
||||
# Close existing connection if any
|
||||
if self._connection is not None:
|
||||
self.close()
|
||||
|
||||
conn_cfg = self._config.connection
|
||||
timeout_cfg = self._config.timeouts
|
||||
|
||||
logger.info(f"Connecting to Snowflake account: {conn_cfg.account}")
|
||||
logger.info(f"Using warehouse: {conn_cfg.warehouse}, database: {conn_cfg.database}")
|
||||
logger.info(f"Authenticator: {conn_cfg.authenticator}")
|
||||
if conn_cfg.authenticator == "externalbrowser":
|
||||
logger.info("Browser window will open for NHS SSO authentication")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Build connection parameters
|
||||
connect_params = {
|
||||
"account": conn_cfg.account,
|
||||
"warehouse": conn_cfg.warehouse,
|
||||
"database": conn_cfg.database,
|
||||
"schema": conn_cfg.schema,
|
||||
"authenticator": conn_cfg.authenticator,
|
||||
"login_timeout": timeout_cfg.login_timeout,
|
||||
"network_timeout": timeout_cfg.connection_timeout,
|
||||
}
|
||||
|
||||
# Optional parameters (only add if set)
|
||||
if conn_cfg.user:
|
||||
connect_params["user"] = conn_cfg.user
|
||||
if conn_cfg.role:
|
||||
connect_params["role"] = conn_cfg.role
|
||||
|
||||
self._connection = snowflake.connector.connect(**connect_params)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(f"Connected to Snowflake successfully in {elapsed:.1f}s")
|
||||
|
||||
# Update connection info
|
||||
self._connection_info = ConnectionInfo(
|
||||
connected=True,
|
||||
account=conn_cfg.account,
|
||||
warehouse=conn_cfg.warehouse,
|
||||
database=conn_cfg.database,
|
||||
schema=conn_cfg.schema,
|
||||
user=self._get_current_user(),
|
||||
role=self._get_current_role(),
|
||||
connected_at=datetime.now(),
|
||||
query_count=0,
|
||||
)
|
||||
|
||||
return self._connection
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.time() - start_time
|
||||
logger.error(f"Failed to connect to Snowflake after {elapsed:.1f}s: {e}")
|
||||
self._connection_info = ConnectionInfo(connected=False)
|
||||
raise SnowflakeConnectionError(f"Failed to connect to Snowflake: {e}") from e
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the Snowflake connection if open."""
|
||||
if self._connection is not None:
|
||||
try:
|
||||
self._connection.close()
|
||||
logger.info("Snowflake connection closed")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing Snowflake connection: {e}")
|
||||
finally:
|
||||
self._connection = None
|
||||
self._connection_info = ConnectionInfo(connected=False)
|
||||
|
||||
def _get_current_user(self) -> str:
|
||||
"""Get the current authenticated user."""
|
||||
if self._connection is None:
|
||||
return ""
|
||||
try:
|
||||
cursor = self._connection.cursor()
|
||||
cursor.execute("SELECT CURRENT_USER()")
|
||||
result = cursor.fetchone()
|
||||
return result[0] if result else ""
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
def _get_current_role(self) -> str:
|
||||
"""Get the current active role."""
|
||||
if self._connection is None:
|
||||
return ""
|
||||
try:
|
||||
cursor = self._connection.cursor()
|
||||
cursor.execute("SELECT CURRENT_ROLE()")
|
||||
result = cursor.fetchone()
|
||||
return result[0] if result else ""
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self) -> Generator[SnowflakeConnection, None, None]:
|
||||
"""
|
||||
Context manager for connection handling.
|
||||
|
||||
Creates a new connection if not already connected, yields the connection,
|
||||
and ensures proper cleanup on exit.
|
||||
|
||||
Yields:
|
||||
Active SnowflakeConnection
|
||||
|
||||
Example:
|
||||
connector = SnowflakeConnector()
|
||||
with connector.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.connect()
|
||||
|
||||
assert self._connection is not None, "Connection should be established"
|
||||
try:
|
||||
yield self._connection
|
||||
finally:
|
||||
# Keep connection open for reuse
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
def get_cursor(
|
||||
self,
|
||||
dict_cursor: bool = False
|
||||
) -> Generator[SnowflakeCursor, None, None]:
|
||||
"""
|
||||
Context manager that provides a cursor.
|
||||
|
||||
Args:
|
||||
dict_cursor: If True, returns cursor that yields dict-like rows
|
||||
|
||||
Yields:
|
||||
SnowflakeCursor for executing queries
|
||||
|
||||
Example:
|
||||
connector = SnowflakeConnector()
|
||||
with connector.get_cursor() as cursor:
|
||||
cursor.execute("SELECT * FROM table LIMIT 10")
|
||||
for row in cursor:
|
||||
print(row)
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.connect()
|
||||
|
||||
assert self._connection is not None, "Connection should be established"
|
||||
cursor: Any = None
|
||||
try:
|
||||
if dict_cursor:
|
||||
cursor = self._connection.cursor(snowflake.connector.DictCursor) # type: ignore[union-attr]
|
||||
else:
|
||||
cursor = self._connection.cursor()
|
||||
yield cursor # type: ignore[misc]
|
||||
self._connection_info.last_query_at = datetime.now()
|
||||
self._connection_info.query_count += 1
|
||||
finally:
|
||||
if cursor is not None:
|
||||
cursor.close()
|
||||
|
||||
def execute(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[tuple] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> list[tuple]:
|
||||
"""
|
||||
Execute a query and return all results.
|
||||
|
||||
Args:
|
||||
query: SQL query to execute
|
||||
params: Optional query parameters for parameterized queries
|
||||
timeout: Optional query timeout in seconds (overrides config)
|
||||
|
||||
Returns:
|
||||
List of result rows as tuples
|
||||
|
||||
Raises:
|
||||
SnowflakeConnectionError: If not connected
|
||||
Various snowflake errors for query issues
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.connect()
|
||||
|
||||
effective_timeout = timeout or self._config.timeouts.query_timeout
|
||||
|
||||
with self.get_cursor() as cursor:
|
||||
logger.info(f"Executing query (timeout={effective_timeout}s)")
|
||||
logger.debug(f"Query: {query[:200]}...")
|
||||
|
||||
if effective_timeout > 0:
|
||||
cursor.execute(f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {effective_timeout}")
|
||||
|
||||
start_time = time.time()
|
||||
cursor.execute(query, params)
|
||||
results = cursor.fetchall()
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
logger.info(f"Query returned {len(results)} rows in {elapsed:.2f}s")
|
||||
return results
|
||||
|
||||
def execute_dict(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[tuple] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Execute a query and return results as list of dictionaries.
|
||||
|
||||
Args:
|
||||
query: SQL query to execute
|
||||
params: Optional query parameters
|
||||
timeout: Optional query timeout in seconds
|
||||
|
||||
Returns:
|
||||
List of result rows as dictionaries
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.connect()
|
||||
|
||||
effective_timeout = timeout or self._config.timeouts.query_timeout
|
||||
|
||||
with self.get_cursor(dict_cursor=True) as cursor:
|
||||
logger.info(f"Executing query (timeout={effective_timeout}s)")
|
||||
logger.debug(f"Query: {query[:200]}...")
|
||||
|
||||
if effective_timeout > 0:
|
||||
cursor.execute(f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {effective_timeout}")
|
||||
|
||||
start_time = time.time()
|
||||
cursor.execute(query, params)
|
||||
results = cursor.fetchall()
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
logger.info(f"Query returned {len(results)} rows in {elapsed:.2f}s")
|
||||
return results # type: ignore[return-value]
|
||||
|
||||
def execute_chunked(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[tuple] = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
timeout: Optional[int] = None,
|
||||
max_rows: Optional[int] = None,
|
||||
) -> Generator[list[tuple], None, None]:
|
||||
"""
|
||||
Execute a query and yield results in chunks for memory efficiency.
|
||||
|
||||
This method is useful for large result sets that would exceed memory
|
||||
if loaded all at once. Results are yielded as chunks of rows.
|
||||
|
||||
Args:
|
||||
query: SQL query to execute
|
||||
params: Optional query parameters for parameterized queries
|
||||
chunk_size: Number of rows per chunk (default from config)
|
||||
timeout: Optional query timeout in seconds (overrides config)
|
||||
max_rows: Maximum total rows to return (default from config, 0 for no limit)
|
||||
|
||||
Yields:
|
||||
List of result rows as tuples for each chunk
|
||||
|
||||
Example:
|
||||
for chunk in connector.execute_chunked("SELECT * FROM large_table"):
|
||||
process_chunk(chunk)
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.connect()
|
||||
|
||||
effective_timeout = timeout or self._config.timeouts.query_timeout
|
||||
effective_chunk_size = chunk_size or self._config.query.chunk_size
|
||||
effective_max_rows = max_rows if max_rows is not None else self._config.query.max_rows
|
||||
|
||||
with self.get_cursor() as cursor:
|
||||
logger.info(f"Executing chunked query (chunk_size={effective_chunk_size}, timeout={effective_timeout}s)")
|
||||
logger.debug(f"Query: {query[:200]}...")
|
||||
|
||||
if effective_timeout > 0:
|
||||
cursor.execute(f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {effective_timeout}")
|
||||
|
||||
start_time = time.time()
|
||||
cursor.execute(query, params)
|
||||
|
||||
total_rows = 0
|
||||
chunk_num = 0
|
||||
|
||||
while True:
|
||||
# Determine how many rows to fetch this chunk
|
||||
if effective_max_rows > 0:
|
||||
remaining = effective_max_rows - total_rows
|
||||
if remaining <= 0:
|
||||
break
|
||||
fetch_size = min(effective_chunk_size, remaining)
|
||||
else:
|
||||
fetch_size = effective_chunk_size
|
||||
|
||||
chunk = cursor.fetchmany(fetch_size)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
chunk_num += 1
|
||||
total_rows += len(chunk)
|
||||
logger.debug(f"Chunk {chunk_num}: {len(chunk)} rows (total: {total_rows})")
|
||||
yield chunk
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(f"Chunked query returned {total_rows} rows in {chunk_num} chunks ({elapsed:.2f}s)")
|
||||
|
||||
def execute_chunked_dict(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[tuple] = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
timeout: Optional[int] = None,
|
||||
max_rows: Optional[int] = None,
|
||||
) -> Generator[list[dict], None, None]:
|
||||
"""
|
||||
Execute a query and yield dict results in chunks for memory efficiency.
|
||||
|
||||
Same as execute_chunked but returns rows as dictionaries.
|
||||
|
||||
Args:
|
||||
query: SQL query to execute
|
||||
params: Optional query parameters
|
||||
chunk_size: Number of rows per chunk (default from config)
|
||||
timeout: Optional query timeout in seconds
|
||||
max_rows: Maximum total rows to return (default from config, 0 for no limit)
|
||||
|
||||
Yields:
|
||||
List of result rows as dictionaries for each chunk
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.connect()
|
||||
|
||||
effective_timeout = timeout or self._config.timeouts.query_timeout
|
||||
effective_chunk_size = chunk_size or self._config.query.chunk_size
|
||||
effective_max_rows = max_rows if max_rows is not None else self._config.query.max_rows
|
||||
|
||||
with self.get_cursor(dict_cursor=True) as cursor:
|
||||
logger.info(f"Executing chunked dict query (chunk_size={effective_chunk_size}, timeout={effective_timeout}s)")
|
||||
logger.debug(f"Query: {query[:200]}...")
|
||||
|
||||
if effective_timeout > 0:
|
||||
cursor.execute(f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {effective_timeout}")
|
||||
|
||||
start_time = time.time()
|
||||
cursor.execute(query, params)
|
||||
|
||||
total_rows = 0
|
||||
chunk_num = 0
|
||||
|
||||
while True:
|
||||
# Determine how many rows to fetch this chunk
|
||||
if effective_max_rows > 0:
|
||||
remaining = effective_max_rows - total_rows
|
||||
if remaining <= 0:
|
||||
break
|
||||
fetch_size = min(effective_chunk_size, remaining)
|
||||
else:
|
||||
fetch_size = effective_chunk_size
|
||||
|
||||
chunk = cursor.fetchmany(fetch_size)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
chunk_num += 1
|
||||
total_rows += len(chunk)
|
||||
logger.debug(f"Chunk {chunk_num}: {len(chunk)} rows (total: {total_rows})")
|
||||
yield chunk # type: ignore[misc]
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(f"Chunked dict query returned {total_rows} rows in {chunk_num} chunks ({elapsed:.2f}s)")
|
||||
|
||||
def execute_with_row_limit(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[tuple] = None,
|
||||
max_rows: Optional[int] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> tuple[list[dict], bool]:
|
||||
"""
|
||||
Execute a query with a row limit and indicate if more rows were available.
|
||||
|
||||
This is useful for pagination or previewing large result sets.
|
||||
|
||||
Args:
|
||||
query: SQL query to execute
|
||||
params: Optional query parameters
|
||||
max_rows: Maximum rows to return (default from config)
|
||||
timeout: Optional query timeout in seconds
|
||||
|
||||
Returns:
|
||||
Tuple of (results list, has_more bool)
|
||||
- results: List of result rows as dictionaries (up to max_rows)
|
||||
- has_more: True if there were more rows than max_rows
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.connect()
|
||||
|
||||
effective_timeout = timeout or self._config.timeouts.query_timeout
|
||||
effective_max_rows = max_rows if max_rows is not None else self._config.query.max_rows
|
||||
|
||||
with self.get_cursor(dict_cursor=True) as cursor:
|
||||
logger.info(f"Executing query with limit (max_rows={effective_max_rows}, timeout={effective_timeout}s)")
|
||||
logger.debug(f"Query: {query[:200]}...")
|
||||
|
||||
if effective_timeout > 0:
|
||||
cursor.execute(f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {effective_timeout}")
|
||||
|
||||
start_time = time.time()
|
||||
cursor.execute(query, params)
|
||||
|
||||
# Fetch one more than max to detect if there are more rows
|
||||
results = cursor.fetchmany(effective_max_rows + 1)
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
has_more = len(results) > effective_max_rows
|
||||
if has_more:
|
||||
results = results[:effective_max_rows]
|
||||
|
||||
logger.info(f"Query returned {len(results)} rows (has_more={has_more}) in {elapsed:.2f}s")
|
||||
return results, has_more # type: ignore[return-value]
|
||||
|
||||
def fetch_activity_data(
|
||||
self,
|
||||
start_date: Optional[date] = None,
|
||||
end_date: Optional[date] = None,
|
||||
provider_codes: Optional[list[str]] = None,
|
||||
max_rows: Optional[int] = None,
|
||||
timeout: Optional[int] = None,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Fetch high-cost drug activity data from Snowflake.
|
||||
|
||||
Queries the CDM.Acute__Conmon__PatientLevelDrugs table and returns
|
||||
data in a format compatible with the existing analysis pipeline.
|
||||
|
||||
Args:
|
||||
start_date: Optional start date for filtering (inclusive)
|
||||
end_date: Optional end date for filtering (inclusive)
|
||||
provider_codes: Optional list of provider codes to filter by
|
||||
max_rows: Maximum rows to return (default from config)
|
||||
timeout: Query timeout in seconds (default from config)
|
||||
|
||||
Returns:
|
||||
List of dictionaries with keys matching expected DataFrame columns:
|
||||
- PseudoNHSNoLinked: Pseudonymised NHS number (for UPID creation)
|
||||
- Provider Code: NHS provider code
|
||||
- PersonKey: Local patient identifier
|
||||
- Drug Name: Raw drug name
|
||||
- Intervention Date: Date of intervention
|
||||
- Price Actual: Cost of intervention
|
||||
- OrganisationName: Provider organisation name
|
||||
- Treatment Function Code: NHS treatment function code
|
||||
- Additional Detail 1-5: Additional details for directory identification
|
||||
|
||||
Raises:
|
||||
SnowflakeConnectionError: If not connected or query fails
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.connect()
|
||||
|
||||
# Build the query
|
||||
table_name = 'DATA_HUB.CDM."Acute__Conmon__PatientLevelDrugs"'
|
||||
|
||||
query = f'''
|
||||
SELECT
|
||||
"PseudoNHSNoLinked",
|
||||
"ProviderCode" AS "Provider Code",
|
||||
"LocalPatientID" AS "PersonKey",
|
||||
"DrugName" AS "Drug Name",
|
||||
"InterventionDate" AS "Intervention Date",
|
||||
"PriceActual" AS "Price Actual",
|
||||
"ProviderName" AS "OrganisationName",
|
||||
"TreatmentFunctionCode" AS "Treatment Function Code",
|
||||
"TreatmentFunctionDesc" AS "Treatment Function Desc",
|
||||
"AdditionalDetail1" AS "Additional Detail 1",
|
||||
"AdditionalDescription1" AS "Additional Description 1",
|
||||
"AdditionalDetail2" AS "Additional Detail 2",
|
||||
"AdditionalDescription2" AS "Additional Description 2",
|
||||
"AdditionalDetail3" AS "Additional Detail 3",
|
||||
"AdditionalDescription3" AS "Additional Description 3",
|
||||
"AdditionalDetail4" AS "Additional Detail 4",
|
||||
"AdditionalDescription4" AS "Additional Description 4",
|
||||
"AdditionalDetail5" AS "Additional Detail 5",
|
||||
"AdditionalDescription5" AS "Additional Description 5"
|
||||
FROM {table_name}
|
||||
WHERE 1=1
|
||||
'''
|
||||
|
||||
params = []
|
||||
|
||||
# Add date filters
|
||||
if start_date:
|
||||
query += ' AND "InterventionDate" >= %s'
|
||||
params.append(start_date.isoformat())
|
||||
if end_date:
|
||||
query += ' AND "InterventionDate" <= %s'
|
||||
params.append(end_date.isoformat())
|
||||
|
||||
# Add provider filter
|
||||
if provider_codes:
|
||||
placeholders = ", ".join(["%s"] * len(provider_codes))
|
||||
query += f' AND "ProviderCode" IN ({placeholders})'
|
||||
params.extend(provider_codes)
|
||||
|
||||
# Add ordering for consistent results
|
||||
query += ' ORDER BY "InterventionDate", "ProviderCode", "PseudoNHSNoLinked"'
|
||||
|
||||
logger.info(f"Fetching activity data from Snowflake")
|
||||
if start_date:
|
||||
logger.info(f" Date range: {start_date} to {end_date or 'now'}")
|
||||
if provider_codes:
|
||||
logger.info(f" Providers: {provider_codes}")
|
||||
|
||||
effective_max_rows = max_rows if max_rows is not None else self._config.query.max_rows
|
||||
effective_timeout = timeout or self._config.timeouts.query_timeout
|
||||
|
||||
# Execute with chunked results for large datasets
|
||||
all_results = []
|
||||
total_rows = 0
|
||||
|
||||
for chunk in self.execute_chunked_dict(
|
||||
query,
|
||||
params=tuple(params) if params else None,
|
||||
timeout=effective_timeout,
|
||||
max_rows=effective_max_rows,
|
||||
):
|
||||
all_results.extend(chunk)
|
||||
total_rows += len(chunk)
|
||||
logger.debug(f"Fetched {total_rows} rows so far...")
|
||||
|
||||
logger.info(f"Fetched {len(all_results)} activity records from Snowflake")
|
||||
return all_results
|
||||
|
||||
def test_connection(self) -> tuple[bool, str]:
|
||||
"""
|
||||
Test the Snowflake connection.
|
||||
|
||||
Returns:
|
||||
Tuple of (success: bool, message: str)
|
||||
"""
|
||||
try:
|
||||
self._check_availability()
|
||||
except SnowflakeNotAvailableError as e:
|
||||
return False, str(e)
|
||||
|
||||
try:
|
||||
self._check_configured()
|
||||
except SnowflakeNotConfiguredError as e:
|
||||
return False, str(e)
|
||||
|
||||
try:
|
||||
self.connect()
|
||||
user = self._get_current_user()
|
||||
role = self._get_current_role()
|
||||
return True, f"Connected as {user} with role {role}"
|
||||
except Exception as e:
|
||||
return False, f"Connection failed: {e}"
|
||||
|
||||
def __enter__(self) -> "SnowflakeConnector":
|
||||
"""Context manager entry."""
|
||||
self.connect()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||
"""Context manager exit."""
|
||||
self.close()
|
||||
|
||||
|
||||
# Module-level singleton for convenience
|
||||
_default_connector: Optional[SnowflakeConnector] = None
|
||||
|
||||
|
||||
def get_connector(config: Optional[SnowflakeConfig] = None) -> SnowflakeConnector:
|
||||
"""
|
||||
Get a Snowflake connector (creates singleton on first call).
|
||||
|
||||
Args:
|
||||
config: Optional configuration. If provided, creates new connector
|
||||
with this config. If None, uses/creates default connector.
|
||||
|
||||
Returns:
|
||||
SnowflakeConnector instance
|
||||
"""
|
||||
global _default_connector
|
||||
|
||||
if config is not None:
|
||||
# Custom config requested, create new connector
|
||||
return SnowflakeConnector(config)
|
||||
|
||||
if _default_connector is None:
|
||||
_default_connector = SnowflakeConnector()
|
||||
|
||||
return _default_connector
|
||||
|
||||
|
||||
def reset_connector() -> None:
|
||||
"""Reset the default connector (closes connection and clears singleton)."""
|
||||
global _default_connector
|
||||
|
||||
if _default_connector is not None:
|
||||
_default_connector.close()
|
||||
_default_connector = None
|
||||
|
||||
|
||||
def is_snowflake_available() -> bool:
|
||||
"""Return True if snowflake-connector-python is installed."""
|
||||
return SNOWFLAKE_AVAILABLE
|
||||
|
||||
|
||||
def is_snowflake_configured() -> bool:
|
||||
"""Return True if Snowflake account is configured."""
|
||||
try:
|
||||
config = get_snowflake_config()
|
||||
return config.is_configured
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
# Export public API
|
||||
__all__ = [
|
||||
"SnowflakeConnector",
|
||||
"SnowflakeConnectionError",
|
||||
"SnowflakeNotConfiguredError",
|
||||
"SnowflakeNotAvailableError",
|
||||
"ConnectionInfo",
|
||||
"get_connector",
|
||||
"reset_connector",
|
||||
"is_snowflake_available",
|
||||
"is_snowflake_configured",
|
||||
"SNOWFLAKE_AVAILABLE",
|
||||
]
|
||||
@@ -0,0 +1,331 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import csv
|
||||
import urllib.request
|
||||
import io # Added for StringIO
|
||||
import re # Added for regex escape and word boundaries
|
||||
from typing import Optional
|
||||
|
||||
from core import PathConfig, default_paths
|
||||
from core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
def drug_names(df, paths: Optional[PathConfig] = None):
|
||||
# Generate dictionary to convert drug names from activity data to generic standardisation
|
||||
if paths is None:
|
||||
paths = default_paths
|
||||
|
||||
d = {}
|
||||
with open(paths.drugnames_csv, 'r', newline='') as f:
|
||||
reader = csv.reader(f, delimiter=',')
|
||||
for drug_name, generic in reader:
|
||||
d[drug_name.upper()] = generic.upper()
|
||||
|
||||
# Map drug names with dictionary generated earlier
|
||||
df["Drug Name"] = df["Drug Name"].str.upper().map(d)
|
||||
|
||||
# Remove (Left eye) or (Right eye) from Drug Name, including whitespace
|
||||
df["Drug Name"] = df["Drug Name"].str.replace(r'\(LEFT EYE\)', '', regex=True) # Escaped parentheses
|
||||
df["Drug Name"] = df["Drug Name"].str.replace(r'\(RIGHT EYE\)', '', regex=True) # Escaped parentheses
|
||||
df["Drug Name"] = df["Drug Name"].str.strip()
|
||||
return df
|
||||
|
||||
|
||||
def patient_id(df):
|
||||
# Generate unique patient ID
|
||||
df["UPID"] = df["Provider Code"].str[:3] + df["PersonKey"].astype(str)
|
||||
return df
|
||||
|
||||
|
||||
def compress_csv(filepath):
|
||||
df = pd.read_csv(filepath)
|
||||
compressed_path = filepath.replace(".csv", "_bz2.csv")
|
||||
df.to_csv(compressed_path, compression="bz2", index=False)
|
||||
return compressed_path
|
||||
|
||||
|
||||
def department_identification(df, paths: Optional[PathConfig] = None):
|
||||
# --- Setup ---
|
||||
if paths is None:
|
||||
paths = default_paths
|
||||
|
||||
# 1. Load directory_list.csv and prepare uppercase versions/pattern
|
||||
try:
|
||||
directory_df = pd.read_csv(paths.directory_list_csv)
|
||||
directory_list = directory_df["directory"].dropna().astype(str).tolist()
|
||||
if not directory_list:
|
||||
raise ValueError("directory_list.csv is empty or contains only NA values.")
|
||||
directory_list_upper = [d.upper() for d in directory_list]
|
||||
# Use word boundaries (\b) to avoid partial matches within words, escape special regex chars
|
||||
dir_pattern_upper = r'\b({})'.format('|'.join(map(re.escape, directory_list_upper)))
|
||||
except FileNotFoundError:
|
||||
logger.error(f"File not found: {paths.directory_list_csv}. Cannot extract directories.")
|
||||
return df
|
||||
except ValueError as e:
|
||||
logger.error(f"Error loading directory list: {e}")
|
||||
return df
|
||||
|
||||
# Simpler pattern for Primary_Source (no word boundaries)
|
||||
dir_pattern_primary_simple = r'({})'.format('|'.join(map(re.escape, directory_list_upper)))
|
||||
|
||||
# 2. Load treatment_function_codes.csv and prepare uppercase mapping
|
||||
treatment_codes = pd.read_csv(paths.treatment_function_codes_csv)
|
||||
mapping_treatment_codes = dict(treatment_codes[['Code', 'Service']].values)
|
||||
mapping_treatment_codes_upper = {k: str(v).upper() for k, v in mapping_treatment_codes.items()}
|
||||
|
||||
# 3. Load drug_directory_list.csv and parse into drug_to_valid_dirs
|
||||
drug_to_valid_dirs: dict[str, set[str]] = {}
|
||||
# Try pandas direct read - much simpler approach
|
||||
drug_dir_df = pd.read_csv(paths.drug_directory_list_csv, skipinitialspace=True)
|
||||
|
||||
# Identify the drug name column (first column) and directory column (second column)
|
||||
drug_col = drug_dir_df.columns[0]
|
||||
dir_col = drug_dir_df.columns[1]
|
||||
|
||||
# Process dataframe directly
|
||||
drug_to_valid_dirs = {}
|
||||
for _, row in drug_dir_df.iterrows():
|
||||
drug_name = str(row[drug_col]).strip().upper()
|
||||
try:
|
||||
# Directories are pipe-separated in the second column
|
||||
dirs_str = str(row[dir_col]) if not pd.isna(row[dir_col]) else ""
|
||||
dirs = {d.strip().upper() for d in dirs_str.split('|') if d.strip()}
|
||||
if drug_name and dirs and drug_name.lower() != 'nan':
|
||||
drug_to_valid_dirs[drug_name] = dirs
|
||||
except Exception:
|
||||
# Silently continue on row errors
|
||||
continue
|
||||
# 4. Create drug_to_single_dir map
|
||||
drug_to_single_dir = {
|
||||
drug: list(dirs)[0]
|
||||
for drug, dirs in drug_to_valid_dirs.items()
|
||||
if len(dirs) == 1
|
||||
}
|
||||
|
||||
# --- Data Preprocessing ---
|
||||
# Keep original extraction columns list
|
||||
additional_detail_columns = ["Additional Detail 1", "Additional Description 1", "Additional Detail 2", "Additional Description 2",
|
||||
"Additional Detail 3", "Additional Description 3", "Additional Detail 4", "Additional Description 4",
|
||||
"Additional Detail 5", "Additional Description 5", "NCDR Treatment Function Name", "Treatment Function Desc"]
|
||||
|
||||
# 6. Convert detail columns to uppercase BEFORE extraction
|
||||
for ad in additional_detail_columns:
|
||||
# Check if column exists and is object/string type before applying .str
|
||||
if ad in df.columns and pd.api.types.is_object_dtype(df[ad]):
|
||||
df[ad] = df[ad].str.upper()
|
||||
|
||||
# Original extraction loop (using original case list for extraction)
|
||||
# Extract directory from specified columns
|
||||
directory_df = pd.read_csv(paths.directory_list_csv)
|
||||
directory_list = directory_df["directory"].tolist() # Reload original case list
|
||||
|
||||
for ad in additional_detail_columns:
|
||||
try:
|
||||
# Ensure column is string type before cleaning
|
||||
if pd.api.types.is_string_dtype(df[ad]):
|
||||
# Extract directly from the uppercased string column
|
||||
extracted = df[ad].str.extract(dir_pattern_upper, expand=False)
|
||||
df.loc[extracted.index, ad] = extracted
|
||||
else:
|
||||
df[ad] = np.nan # Set non-string columns to NaN
|
||||
except AttributeError: # Skip columns that might not exist or are not string type
|
||||
df[ad] = np.nan # Ensure column exists but set to NaN if error
|
||||
except Exception as e: # Catch other potential errors during extract
|
||||
logger.error(f"Error processing column {ad}: {e}")
|
||||
df[ad] = np.nan
|
||||
|
||||
# 7. Process Treatment Function Code
|
||||
df["Treatment Function Code"].replace(np.nan, 0, inplace=True)
|
||||
# Ensure it's int type before mapping, handle potential errors
|
||||
try:
|
||||
df["Treatment Function Code"] = df["Treatment Function Code"].astype(int)
|
||||
except ValueError:
|
||||
# Handle cases where conversion to int fails (e.g., non-numeric values)
|
||||
# Try coercing errors to NaN, then fillna with 0
|
||||
df["Treatment Function Code"] = pd.to_numeric(df["Treatment Function Code"], errors='coerce').fillna(0).astype(int)
|
||||
|
||||
df["Treatment Function Code"] = df["Treatment Function Code"].map(mapping_treatment_codes_upper)
|
||||
df.rename(columns={'Treatment Function Code': 'Fallback_Source'}, inplace=True)
|
||||
|
||||
# Apply replacements before combining
|
||||
df.replace('MEDICAL OPHTHALMOLOGY', 'OPHTHALMOLOGY', inplace=True)
|
||||
|
||||
# --- Single Directory Assignment ---
|
||||
# 8. Apply single directory override
|
||||
# Ensure Drug Name is suitable for mapping (already done in drug_names func)
|
||||
df['Directory'] = df['Drug Name'].map(drug_to_single_dir)
|
||||
|
||||
# Initialize Directory_Source column - track which fallback level was used
|
||||
df['Directory_Source'] = pd.NA
|
||||
# Mark rows where single valid directory was assigned
|
||||
df.loc[df['Directory'].notna(), 'Directory_Source'] = 'SINGLE_VALID_DIR'
|
||||
|
||||
# --- Prepare Fallback Logic ---
|
||||
# 9. Create Primary source from Additional Detail 1
|
||||
if 'Additional Detail 1' in df.columns:
|
||||
df['Primary_Source'] = df['Additional Detail 1'].astype(pd.StringDtype())
|
||||
df['Primary_Source'] = df['Primary_Source'].str.upper() # Apply upper to strings
|
||||
else:
|
||||
df['Primary_Source'] = pd.NA # Use pd.NA for StringDtype
|
||||
|
||||
# Extract actual directory name using the pattern
|
||||
try:
|
||||
# Use simpler pattern for primary source
|
||||
df['Extracted_Primary_Dir'] = df['Primary_Source'].str.extract(dir_pattern_primary_simple, expand=False, flags=re.IGNORECASE)
|
||||
df['Extracted_Fallback_Dir'] = df['Fallback_Source'].str.extract(dir_pattern_upper, expand=False, flags=re.IGNORECASE)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during directory extraction: {e}")
|
||||
# Assign NA columns if extraction fails
|
||||
df['Extracted_Primary_Dir'] = pd.NA
|
||||
df['Extracted_Fallback_Dir'] = pd.NA
|
||||
|
||||
# Strip potential whitespace from extracted directories
|
||||
if 'Extracted_Primary_Dir' in df.columns:
|
||||
df['Extracted_Primary_Dir'] = df['Extracted_Primary_Dir'].str.strip()
|
||||
if 'Extracted_Fallback_Dir' in df.columns:
|
||||
df['Extracted_Fallback_Dir'] = df['Extracted_Fallback_Dir'].str.strip()
|
||||
|
||||
# 10. Combine sources, prioritizing Primary_Source
|
||||
# Combine EXTRACTED directories
|
||||
df['Primary_Directory'] = df['Extracted_Primary_Dir'].fillna(df['Extracted_Fallback_Dir'])
|
||||
|
||||
# Track extraction source for Directory_Source column
|
||||
# Rows where we have Extracted_Primary_Dir will use EXTRACTED_PRIMARY
|
||||
# Rows where we only have Extracted_Fallback_Dir will use EXTRACTED_FALLBACK
|
||||
df['_extracted_source'] = pd.NA
|
||||
df.loc[df['Extracted_Primary_Dir'].notna(), '_extracted_source'] = 'EXTRACTED_PRIMARY'
|
||||
df.loc[(df['Extracted_Primary_Dir'].isna()) & (df['Extracted_Fallback_Dir'].notna()), '_extracted_source'] = 'EXTRACTED_FALLBACK'
|
||||
|
||||
# 11. Clean up intermediate columns
|
||||
df.drop(columns=['Primary_Source', 'Fallback_Source', 'Extracted_Primary_Dir', 'Extracted_Fallback_Dir'], inplace=True, errors='ignore')
|
||||
|
||||
# --- Identify Rows Needing Calculation ---
|
||||
# 12. Filter rows where Directory is not yet assigned
|
||||
df_to_process = df[df['Directory'].isnull()].copy()
|
||||
|
||||
# --- Calculate Most Frequent Valid Directory ---
|
||||
# 13. Drop rows without a potential primary directory
|
||||
df_to_process.dropna(subset=['Primary_Directory'], inplace=True)
|
||||
|
||||
# 14. Group and count potential directories
|
||||
if not df_to_process.empty:
|
||||
df_counts = df_to_process.groupby(['UPID', 'Drug Name', 'Primary_Directory'], observed=True)['Primary_Directory'].count().reset_index(name='count')
|
||||
|
||||
# 15. Sort by count descending
|
||||
df_counts.sort_values(['UPID', 'Drug Name', 'count'], ascending=[True, True, False], inplace=True)
|
||||
|
||||
# 16. Define helper function
|
||||
def find_first_valid_dir(group, drug_map):
|
||||
drug_name = group['Drug Name'].iloc[0]
|
||||
valid_dirs = drug_map.get(drug_name, set())
|
||||
|
||||
if not valid_dirs:
|
||||
return np.nan
|
||||
|
||||
for dir_candidate in group['Primary_Directory']:
|
||||
# Skip NA values
|
||||
if pd.isna(dir_candidate):
|
||||
continue
|
||||
|
||||
# Check if valid directory for this drug
|
||||
if isinstance(dir_candidate, str) and dir_candidate in valid_dirs:
|
||||
return dir_candidate
|
||||
|
||||
return np.nan # No valid directory found in the group
|
||||
|
||||
# 17. Group by UPID and Drug Name
|
||||
valid_groups = df_counts.groupby(['UPID', 'Drug Name'], observed=True, group_keys=False)
|
||||
|
||||
# 18. Apply helper function to find the best valid directory
|
||||
calculated_dirs = valid_groups.apply(lambda grp: find_first_valid_dir(grp, drug_to_valid_dirs))
|
||||
|
||||
# 19. Reset index to get UPID, Drug Name columns
|
||||
final_mapping = calculated_dirs.reset_index()
|
||||
|
||||
# 20. Rename the resulting column
|
||||
final_mapping.columns = ['UPID', 'Drug Name', 'Calculated_Directory']
|
||||
|
||||
# --- Merge Results and Finalize ---
|
||||
# 21. Merge calculated directories back to the main DataFrame
|
||||
df = pd.merge(df, final_mapping, on=['UPID', 'Drug Name'], how='left')
|
||||
|
||||
# 22. Fill NaN Directories with the calculated ones and track source
|
||||
# Find rows that will be filled from Calculated_Directory
|
||||
rows_to_fill = df['Directory'].isna() & df['Calculated_Directory'].notna()
|
||||
# For these rows, set Directory_Source based on _extracted_source (where the calculated dir came from)
|
||||
# The "calculated" directory is still derived from extraction, just via frequency analysis
|
||||
df.loc[rows_to_fill, 'Directory_Source'] = df.loc[rows_to_fill, '_extracted_source'].fillna('CALCULATED_MOST_FREQ')
|
||||
# Replace with the actual value of _extracted_source or fall back to CALCULATED_MOST_FREQ
|
||||
# Actually, let's simplify: if we're using the calculated most frequent directory, that's CALCULATED_MOST_FREQ
|
||||
df.loc[rows_to_fill, 'Directory_Source'] = 'CALCULATED_MOST_FREQ'
|
||||
|
||||
df['Directory'].fillna(df['Calculated_Directory'], inplace=True)
|
||||
|
||||
# 23. Drop temporary columns
|
||||
df.drop(columns=['Calculated_Directory', 'Primary_Directory', '_extracted_source'], inplace=True, errors='ignore')
|
||||
|
||||
else:
|
||||
# If df_to_process was empty, still need to drop temporary columns
|
||||
df.drop(columns=['Primary_Directory', '_extracted_source'], inplace=True, errors='ignore')
|
||||
|
||||
# 24. Drop rows with missing UPID (original logic)
|
||||
df['UPID'].replace('', np.nan, inplace=True) # Ensure empty strings are NaN
|
||||
df_orig = df.copy() # Save before dropna for future reference if needed
|
||||
df.dropna(subset=['UPID'], inplace=True)
|
||||
|
||||
# 25. Export rows with NA Directory to CSV for analysis (keep this for diagnostics)
|
||||
na_directory_rows = df[df['Directory'].isna()].copy()
|
||||
|
||||
# Export to CSV if there are any NA Directory rows
|
||||
if len(na_directory_rows) > 0:
|
||||
na_directory_rows.to_csv(paths.na_directory_rows_csv, index=False)
|
||||
|
||||
# 26. FALLBACK MECHANISM 1: Infer directory based on same UPID
|
||||
# Create a mapping of most frequent directory per UPID (only for UPIDs with a directory)
|
||||
if len(df[df['Directory'].isna()]) > 0:
|
||||
# First get valid directories per UPID
|
||||
valid_upid_dirs = df[df['Directory'].notna()].groupby('UPID')['Directory'].agg(
|
||||
lambda x: x.value_counts().index[0] if len(x.value_counts()) > 0 else None
|
||||
).to_dict()
|
||||
|
||||
# Apply UPID-based inference and track source
|
||||
for idx in df[df['Directory'].isna()].index:
|
||||
upid = df.loc[idx, 'UPID']
|
||||
if upid in valid_upid_dirs and valid_upid_dirs[upid] is not None:
|
||||
df.loc[idx, 'Directory'] = valid_upid_dirs[upid]
|
||||
df.loc[idx, 'Directory_Source'] = 'UPID_INFERENCE'
|
||||
|
||||
# 27. FALLBACK MECHANISM 2: Label remaining NA as "Undefined"
|
||||
# Track rows that will be marked as Undefined
|
||||
rows_undefined = df['Directory'].isna()
|
||||
df.loc[rows_undefined, 'Directory_Source'] = 'UNDEFINED'
|
||||
# Fill remaining NA directories with "Undefined"
|
||||
df['Directory'].fillna("Undefined", inplace=True)
|
||||
|
||||
# 28. Return the processed DataFrame
|
||||
return df
|
||||
|
||||
|
||||
|
||||
def ta_list_get(paths: Optional[PathConfig] = None):
|
||||
if paths is None:
|
||||
paths = default_paths
|
||||
|
||||
link = "https://www.nice.org.uk/Media/Default/About/what-we-do/NICE-guidance/NICE-technology-appraisals/TA%20recommendations.xlsx"
|
||||
urllib.request.urlretrieve(link, paths.ta_recommendations_xlsx)
|
||||
ta_db = pd.read_excel(paths.ta_recommendations_xlsx, index_col=0)
|
||||
|
||||
# Filter out TA's which are not Recommended or not Pharmaceutical
|
||||
ta_db = ta_db[ta_db["Categorisation (for specific recommendation)"].isin(["Recommended", "Optimised"])]
|
||||
ta_db = ta_db[ta_db["Technology type"] == "Pharmaceutical"]
|
||||
|
||||
# Amend TA001 strings to only the integer
|
||||
ta_db["TA ID"] = ta_db["TA ID"].str.replace(r'\D+', '', regex=True).astype(int)
|
||||
ta_db["TA ID"] = "NICE TA" + ta_db["TA ID"].astype(str)
|
||||
ta_series = ta_db[["TA ID", "Indication"]].drop_duplicates()
|
||||
return ta_series
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user