refactor: slim pathways.db from 351 MB to 3.5 MB by removing unused tables

Drop fact_interventions (440K rows), mv_patient_treatment_summary (35K rows),
ref_drug_snomed_mapping (144K rows), and processed_files — all unused since
the app moved to pre-computed pathway_nodes.

Key changes:
- Rewrite load_data() to source from pathway_nodes + pathway_refresh_log
- Remove 7 dead methods and 8 dead state vars from pathways_app.py
- Delete patient_data.py, load_snomed_mapping.py, test_large_dataset_performance.py
- Remove SQLiteDataLoader (depended on fact_interventions)
- Remove file tracking schema (processed_files tracked fact_interventions loads)
- Remove legacy diagnosis functions from diagnosis_lookup.py
- Add source_row_count migration for pathway_refresh_log
- Clean all cross-references in __init__.py, data_source.py, migrate.py
This commit is contained in:
Andrew Charlwood
2026-02-06 08:51:03 +00:00
parent bb93c1673e
commit 778ed99ef6
11 changed files with 95 additions and 3653 deletions
+5 -1
View File
@@ -176,6 +176,7 @@ def log_refresh_complete(
record_count: int, record_count: int,
date_filter_counts: dict[str, int], date_filter_counts: dict[str, int],
duration_seconds: float, duration_seconds: float,
source_row_count: Optional[int] = None,
) -> None: ) -> None:
"""Log the successful completion of a refresh operation.""" """Log the successful completion of a refresh operation."""
conn.execute(""" conn.execute("""
@@ -184,13 +185,15 @@ def log_refresh_complete(
status = 'completed', status = 'completed',
record_count = ?, record_count = ?,
date_filter_counts = ?, date_filter_counts = ?,
processing_duration_seconds = ? processing_duration_seconds = ?,
source_row_count = ?
WHERE refresh_id = ? WHERE refresh_id = ?
""", ( """, (
datetime.now().isoformat(), datetime.now().isoformat(),
record_count, record_count,
json.dumps(date_filter_counts), json.dumps(date_filter_counts),
duration_seconds, duration_seconds,
source_row_count,
refresh_id, refresh_id,
)) ))
conn.commit() conn.commit()
@@ -517,6 +520,7 @@ def refresh_pathways(
record_count=stats["total_records"], record_count=stats["total_records"],
date_filter_counts=stats["date_filter_counts"], date_filter_counts=stats["date_filter_counts"],
duration_seconds=elapsed, duration_seconds=elapsed,
source_row_count=stats.get("snowflake_rows"),
) )
# Verify final counts # Verify final counts
-65
View File
@@ -24,15 +24,6 @@ from data_processing.schema import (
REF_DRUG_DIRECTORY_MAP_SCHEMA, REF_DRUG_DIRECTORY_MAP_SCHEMA,
REF_DRUG_INDICATION_CLUSTERS_SCHEMA, REF_DRUG_INDICATION_CLUSTERS_SCHEMA,
REFERENCE_TABLES_SCHEMA, REFERENCE_TABLES_SCHEMA,
# Fact table schemas
FACT_INTERVENTIONS_SCHEMA,
FACT_TABLES_SCHEMA,
# Materialized view schemas
MV_PATIENT_TREATMENT_SUMMARY_SCHEMA,
MATERIALIZED_VIEWS_SCHEMA,
# File tracking schemas
PROCESSED_FILES_SCHEMA,
FILE_TRACKING_SCHEMA,
# Combined schema # Combined schema
ALL_TABLES_SCHEMA, ALL_TABLES_SCHEMA,
# Reference table functions # Reference table functions
@@ -40,16 +31,6 @@ from data_processing.schema import (
drop_reference_tables, drop_reference_tables,
get_reference_table_counts, get_reference_table_counts,
verify_reference_tables_exist, verify_reference_tables_exist,
# Fact table functions
create_fact_tables,
drop_fact_tables,
get_fact_table_counts,
verify_fact_tables_exist,
# File tracking functions
create_file_tracking_tables,
drop_file_tracking_tables,
get_file_tracking_counts,
verify_file_tracking_tables_exist,
# Combined functions # Combined functions
create_all_tables, create_all_tables,
drop_all_tables, drop_all_tables,
@@ -81,27 +62,12 @@ from data_processing.reference_data import (
from data_processing.loader import ( from data_processing.loader import (
DataLoader, DataLoader,
FileDataLoader, FileDataLoader,
SQLiteDataLoader,
LoadResult, LoadResult,
get_loader, get_loader,
REQUIRED_COLUMNS, REQUIRED_COLUMNS,
OPTIONAL_COLUMNS, OPTIONAL_COLUMNS,
) )
# Patient data migration functions
from data_processing.patient_data import (
PatientDataLoadResult,
load_patient_data,
get_patient_data_stats,
list_processed_files,
calculate_file_hash,
# Materialized view functions
MVRefreshResult,
refresh_patient_treatment_summary,
get_patient_summary_stats,
verify_mv_consistency,
)
# Snowflake connector # Snowflake connector
from data_processing.snowflake_connector import ( from data_processing.snowflake_connector import (
SnowflakeConnector, SnowflakeConnector,
@@ -165,15 +131,6 @@ __all__ = [
"REF_DRUG_DIRECTORY_MAP_SCHEMA", "REF_DRUG_DIRECTORY_MAP_SCHEMA",
"REF_DRUG_INDICATION_CLUSTERS_SCHEMA", "REF_DRUG_INDICATION_CLUSTERS_SCHEMA",
"REFERENCE_TABLES_SCHEMA", "REFERENCE_TABLES_SCHEMA",
# Fact table schemas
"FACT_INTERVENTIONS_SCHEMA",
"FACT_TABLES_SCHEMA",
# Materialized view schemas
"MV_PATIENT_TREATMENT_SUMMARY_SCHEMA",
"MATERIALIZED_VIEWS_SCHEMA",
# File tracking schemas
"PROCESSED_FILES_SCHEMA",
"FILE_TRACKING_SCHEMA",
# Combined schema # Combined schema
"ALL_TABLES_SCHEMA", "ALL_TABLES_SCHEMA",
# Reference table functions # Reference table functions
@@ -181,16 +138,6 @@ __all__ = [
"drop_reference_tables", "drop_reference_tables",
"get_reference_table_counts", "get_reference_table_counts",
"verify_reference_tables_exist", "verify_reference_tables_exist",
# Fact table functions
"create_fact_tables",
"drop_fact_tables",
"get_fact_table_counts",
"verify_fact_tables_exist",
# File tracking functions
"create_file_tracking_tables",
"drop_file_tracking_tables",
"get_file_tracking_counts",
"verify_file_tracking_tables_exist",
# Combined functions # Combined functions
"create_all_tables", "create_all_tables",
"drop_all_tables", "drop_all_tables",
@@ -216,22 +163,10 @@ __all__ = [
# Data loader abstractions # Data loader abstractions
"DataLoader", "DataLoader",
"FileDataLoader", "FileDataLoader",
"SQLiteDataLoader",
"LoadResult", "LoadResult",
"get_loader", "get_loader",
"REQUIRED_COLUMNS", "REQUIRED_COLUMNS",
"OPTIONAL_COLUMNS", "OPTIONAL_COLUMNS",
# Patient data migration
"PatientDataLoadResult",
"load_patient_data",
"get_patient_data_stats",
"list_processed_files",
"calculate_file_hash",
# Materialized view functions
"MVRefreshResult",
"refresh_patient_treatment_summary",
"get_patient_summary_stats",
"verify_mv_consistency",
# Snowflake connector # Snowflake connector
"SnowflakeConnector", "SnowflakeConnector",
"SnowflakeConnectionError", "SnowflakeConnectionError",
+13 -49
View File
@@ -232,9 +232,9 @@ class DataSourceManager:
) )
def _check_sqlite_status(self) -> SourceStatus: def _check_sqlite_status(self) -> SourceStatus:
"""Check if SQLite database is available with data.""" """Check if SQLite database is available with pathway data."""
try: try:
from data_processing.database import default_db_manager, default_db_config from data_processing.database import default_db_config
db_path = self._sqlite_db_path or Path(default_db_config.db_path) db_path = self._sqlite_db_path or Path(default_db_config.db_path)
@@ -252,22 +252,22 @@ class DataSourceManager:
config = DatabaseConfig(db_path=db_path) config = DatabaseConfig(db_path=db_path)
manager = DatabaseManager(config) manager = DatabaseManager(config)
if not manager.table_exists("fact_interventions"): if not manager.table_exists("pathway_nodes"):
return SourceStatus( return SourceStatus(
source_type=DataSourceType.SQLITE, source_type=DataSourceType.SQLITE,
available=False, available=False,
configured=True, configured=True,
message="fact_interventions table not found", message="pathway_nodes table not found",
last_checked=datetime.now(), last_checked=datetime.now(),
) )
count = manager.get_table_count("fact_interventions") count = manager.get_table_count("pathway_nodes")
if count == 0: if count == 0:
return SourceStatus( return SourceStatus(
source_type=DataSourceType.SQLITE, source_type=DataSourceType.SQLITE,
available=False, available=False,
configured=True, configured=True,
message="fact_interventions table is empty", message="pathway_nodes table is empty",
last_checked=datetime.now(), last_checked=datetime.now(),
) )
@@ -275,7 +275,7 @@ class DataSourceManager:
source_type=DataSourceType.SQLITE, source_type=DataSourceType.SQLITE,
available=True, available=True,
configured=True, configured=True,
message=f"SQLite database ready ({count:,} rows)", message=f"SQLite database ready ({count:,} pathway nodes)",
last_checked=datetime.now(), last_checked=datetime.now(),
) )
except Exception as e: except Exception as e:
@@ -535,49 +535,13 @@ class DataSourceManager:
drugs: Optional[list[str]], drugs: Optional[list[str]],
directories: Optional[list[str]], directories: Optional[list[str]],
) -> Optional[DataSourceResult]: ) -> Optional[DataSourceResult]:
"""Try to get data from SQLite.""" """Try to get data from SQLite.
import time
try: Note: Raw intervention data is no longer stored in SQLite.
from data_processing.loader import SQLiteDataLoader The app now uses pre-computed pathway_nodes via load_pathway_data().
This fallback is retained for interface compatibility but always returns None.
# Determine database path """
db_path = self._sqlite_db_path logger.debug("SQLite raw data fallback skipped (fact_interventions removed)")
if db_path is None:
from data_processing.database import default_db_config
db_path = Path(default_db_config.db_path)
loader = SQLiteDataLoader(
db_path=db_path,
date_range=(start_date, end_date) if start_date and end_date else None,
trusts=trusts,
drugs=drugs,
directories=directories,
)
# Check if source is valid
is_valid, msg = loader.validate_source()
if not is_valid:
logger.debug(f"SQLite not available: {msg}")
return None
start_time = time.time()
result = loader.load()
load_time = time.time() - start_time
logger.info(f"SQLite loaded {result.row_count} rows in {load_time:.2f}s")
return DataSourceResult(
df=result.df,
source_type=DataSourceType.SQLITE,
source_detail=str(db_path),
row_count=result.row_count,
cached=False,
from_fallback=False,
load_time_seconds=load_time,
)
except Exception as e:
logger.warning(f"SQLite query failed: {e}")
return None return None
def _try_file( def _try_file(
+2 -531
View File
@@ -78,42 +78,6 @@ class DrugIndicationMatchRate:
sample_unmatched: list[str] = field(default_factory=list) # Sample patient IDs sample_unmatched: list[str] = field(default_factory=list) # Sample patient IDs
@dataclass
class DrugSnomedMapping:
    """SNOMED code mapping for a drug from ref_drug_snomed_mapping.

    One instance describes one (drug, SNOMED code) row as returned by
    get_drug_snomed_codes(), which substitutes "" for NULL text columns.
    """
    # SNOMED concept code, kept as a string.
    snomed_code: str
    # Description stored alongside the SNOMED code.
    snomed_description: str
    # Search term associated with this mapping row.
    search_term: str
    # Primary directorate associated with this mapping row.
    primary_directorate: str
    # Indication text for the mapping; empty when not recorded.
    indication: str = ""
    # Value of the ta_id column; empty when not recorded.
    ta_id: str = ""
@dataclass
class DirectSnomedMatchResult:
    """Result of direct SNOMED code lookup in GP records.

    Built by patient_has_indication_direct(): a no-match result carries
    matched=False and source="NONE"; a match fills in the optional fields.
    """
    # Patient identifier the lookup was run for.
    patient_pseudonym: str
    # True when one of the requested SNOMED codes was found for the patient.
    matched: bool
    # Matched SNOMED code (None when there was no match).
    snomed_code: Optional[str] = None
    # Description of the matched code, taken from the drug mapping.
    snomed_description: Optional[str] = None
    # Search term of the mapping row that owns the matched code.
    search_term: Optional[str] = None
    # Primary directorate of the mapping row that owns the matched code.
    primary_directorate: Optional[str] = None
    # EventDateTime of the matched GP record.
    event_date: Optional[datetime] = None
    source: str = "DIRECT_SNOMED"  # DIRECT_SNOMED | NONE
@dataclass
class DirectorateAssignment:
    """Result of directorate assignment for a patient-drug combination.

    Returned by get_directorate_from_diagnosis(): source is "DIAGNOSIS" when a
    GP record match supplied the directorate, otherwise "FALLBACK" with
    directorate left as None.
    """
    # Patient identifier the assignment was requested for.
    upid: str
    # Drug name the assignment was requested for.
    drug_name: str
    # Assigned directorate; None when no diagnosis-based match was found.
    directorate: Optional[str]
    # Search term of the matched mapping (diagnosis matches only).
    search_term: Optional[str] = None
    source: str = "FALLBACK"  # DIAGNOSIS | FALLBACK
    # SNOMED code that produced the match (diagnosis matches only).
    snomed_code: Optional[str] = None
    # Event date of the matched GP record (diagnosis matches only).
    event_date: Optional[datetime] = None
def get_drug_clusters( def get_drug_clusters(
drug_name: str, drug_name: str,
db_manager: Optional[DatabaseManager] = None db_manager: Optional[DatabaseManager] = None
@@ -180,266 +144,6 @@ def get_drug_cluster_ids(
return list(set(c["cluster_id"] for c in clusters)) return list(set(c["cluster_id"] for c in clusters))
def get_drug_snomed_codes(
    drug_name: str,
    db_manager: Optional[DatabaseManager] = None
) -> list[DrugSnomedMapping]:
    """
    Look up every SNOMED mapping row recorded for a drug in SQLite.

    Rows come from the local ref_drug_snomed_mapping table and are matched
    case-insensitively against both cleaned_drug_name and drug_name.

    Args:
        drug_name: Drug name to look up.
        db_manager: Optional DatabaseManager (defaults to default_db_manager).

    Returns:
        List of DrugSnomedMapping ordered by search_term then snomed_code.
        NULL text columns are returned as "". Any database error is logged
        and yields an empty list.
    """
    manager = default_db_manager if db_manager is None else db_manager

    sql = """
        SELECT DISTINCT
            snomed_code,
            snomed_description,
            search_term,
            primary_directorate,
            indication,
            ta_id
        FROM ref_drug_snomed_mapping
        WHERE UPPER(cleaned_drug_name) = UPPER(?)
           OR UPPER(drug_name) = UPPER(?)
        ORDER BY search_term, snomed_code
    """

    try:
        with manager.get_connection() as conn:
            fetched = conn.execute(sql, (drug_name, drug_name)).fetchall()

            mappings = [
                DrugSnomedMapping(
                    snomed_code=record["snomed_code"],
                    snomed_description=record["snomed_description"] or "",
                    search_term=record["search_term"] or "",
                    primary_directorate=record["primary_directorate"] or "",
                    indication=record["indication"] or "",
                    ta_id=record["ta_id"] or "",
                )
                for record in fetched
            ]

            logger.debug(f"Found {len(mappings)} SNOMED mappings for drug '{drug_name}'")
            return mappings

    except Exception as e:
        # Best-effort lookup: callers treat "no mappings" as a fallback signal.
        logger.error(f"Error getting SNOMED codes for drug '{drug_name}': {e}")
        return []
def patient_has_indication_direct(
    patient_pseudonym: str,
    drug_snomed_mappings: list[DrugSnomedMapping],
    connector: Optional[SnowflakeConnector] = None,
    before_date: Optional[date] = None,
) -> DirectSnomedMatchResult:
    """
    Find the patient's most recent GP record carrying one of the given codes.

    Queries PrimaryCareClinicalCoding for exact SNOMED code matches (no
    cluster expansion) and keeps only the newest row by EventDateTime.

    Args:
        patient_pseudonym: Patient's pseudonymised NHS number
        drug_snomed_mappings: Mappings from get_drug_snomed_codes()
        connector: Optional SnowflakeConnector (defaults to singleton)
        before_date: Optional date - only consider records before this date

    Returns:
        DirectSnomedMatchResult; matched=False / source="NONE" when there are
        no mappings, Snowflake is unavailable or unconfigured, nothing
        matches, or the query fails.
    """
    no_match = DirectSnomedMatchResult(
        patient_pseudonym=patient_pseudonym,
        matched=False,
        source="NONE",
    )

    # Guard clauses: nothing to look for, or nowhere to look.
    if not drug_snomed_mappings:
        return no_match
    if not SNOWFLAKE_AVAILABLE:
        logger.warning("Snowflake connector not available")
        return no_match
    if not is_snowflake_configured():
        logger.warning("Snowflake not configured - cannot check GP records")
        return no_match

    gp_connector = get_connector() if connector is None else connector

    # snomed_code -> (search_term, primary_directorate, snomed_description);
    # a later mapping with the same code overwrites an earlier one.
    details_by_code = {}
    for mapping in drug_snomed_mappings:
        details_by_code[mapping.snomed_code] = (
            mapping.search_term,
            mapping.primary_directorate,
            mapping.snomed_description,
        )

    codes = list(details_by_code)
    placeholders = ", ".join("%s" for _ in codes)

    sql_parts = [f'''
        SELECT
            "SNOMEDCode",
            "EventDateTime"
        FROM DATA_HUB.PHM."PrimaryCareClinicalCoding"
        WHERE "PatientPseudonym" = %s
          AND "SNOMEDCode" IN ({placeholders})
    ''']
    bind_values: list = [patient_pseudonym, *codes]

    if before_date:
        sql_parts.append(' AND "EventDateTime" < %s')
        bind_values.append(before_date.isoformat())

    # Newest record first; only that single row is needed.
    sql_parts.append(' ORDER BY "EventDateTime" DESC LIMIT 1')
    sql = "".join(sql_parts)

    try:
        rows = gp_connector.execute_dict(sql, tuple(bind_values))
        if not rows:
            return no_match

        newest = rows[0]
        code_found = newest.get("SNOMEDCode")
        recorded_at = newest.get("EventDateTime")

        if not code_found or code_found not in details_by_code:
            return no_match

        term, directorate, description = details_by_code[code_found]
        return DirectSnomedMatchResult(
            patient_pseudonym=patient_pseudonym,
            matched=True,
            snomed_code=code_found,
            snomed_description=description,
            search_term=term,
            primary_directorate=directorate,
            event_date=recorded_at,
            source="DIRECT_SNOMED",
        )

    except Exception as e:
        logger.error(f"Error checking direct SNOMED for patient '{patient_pseudonym}': {e}")
        return no_match
def get_directorate_from_diagnosis(
    upid: str,
    drug_name: str,
    connector: Optional[SnowflakeConnector] = None,
    db_manager: Optional[DatabaseManager] = None,
    before_date: Optional[date] = None,
    *,
    patient_pseudonym: Optional[str] = None,
) -> DirectorateAssignment:
    """
    Get directorate assignment for a patient-drug combination using diagnosis-based lookup.

    Attempts to assign a directorate from the patient's GP records via direct
    SNOMED code matching. When no match is found a FALLBACK result is
    returned and the caller is expected to apply its own assignment logic.

    Workflow:
    1. Get all SNOMED codes for the drug from ref_drug_snomed_mapping
    2. Query the patient's GP records for matching SNOMED codes
    3. If match found → return diagnosis-based directorate and search_term
    4. If no match → return FALLBACK result (caller handles fallback logic)

    Args:
        upid: Patient's unique patient ID (Provider Code[:3] + PersonKey)
        drug_name: Drug name to look up
        connector: Optional SnowflakeConnector (defaults to singleton)
        db_manager: Optional DatabaseManager (defaults to default_db_manager)
        before_date: Optional date - only check diagnoses before this date
        patient_pseudonym: Optional identifier to match against
            PatientPseudonym in GP records. When omitted, the legacy
            behaviour is kept: the first three characters of ``upid`` are
            stripped and the remainder is used as the pseudonym.

    Returns:
        DirectorateAssignment with directorate, search_term, and source
        ("DIAGNOSIS" on a match, otherwise "FALLBACK").
    """
    fallback = DirectorateAssignment(
        upid=upid,
        drug_name=drug_name,
        directorate=None,
        source="FALLBACK",
    )

    # Step 1: Get SNOMED codes for the drug
    drug_snomed_mappings = get_drug_snomed_codes(drug_name, db_manager)

    if not drug_snomed_mappings:
        logger.debug(f"No SNOMED mappings found for drug '{drug_name}' - using fallback")
        return fallback

    # Step 2: Check Snowflake availability
    if not SNOWFLAKE_AVAILABLE:
        logger.debug("Snowflake not available - using fallback")
        return fallback

    if not is_snowflake_configured():
        logger.debug("Snowflake not configured - using fallback")
        return fallback

    # Step 3: Resolve the identifier used to query GP records
    if patient_pseudonym is None:
        # Legacy derivation: UPID = ProviderCode[:3] + PersonKey, so the
        # PersonKey is recovered by dropping the three-character prefix.
        # NOTE(review): batch_lookup_indication_groups() states that PersonKey
        # is a provider-local ID that does not match GP records and uses
        # PseudoNHSNoLinked instead; callers that have that value should pass
        # it via ``patient_pseudonym`` - confirm whether this derived value
        # can ever match.
        patient_pseudonym = upid[3:] if len(upid) > 3 else upid

    # Step 4: Check patient's GP records for matching SNOMED codes
    match_result = patient_has_indication_direct(
        patient_pseudonym=patient_pseudonym,
        drug_snomed_mappings=drug_snomed_mappings,
        connector=connector,
        before_date=before_date,
    )

    if match_result.matched and match_result.primary_directorate:
        return DirectorateAssignment(
            upid=upid,
            drug_name=drug_name,
            directorate=match_result.primary_directorate,
            search_term=match_result.search_term,
            source="DIAGNOSIS",
            snomed_code=match_result.snomed_code,
            event_date=match_result.event_date,
        )

    # No match (or a match without a directorate) - caller handles fallback
    return fallback
def get_cluster_snomed_codes( def get_cluster_snomed_codes(
cluster_id: str, cluster_id: str,
connector: Optional[SnowflakeConnector] = None, connector: Optional[SnowflakeConnector] = None,
@@ -864,229 +568,6 @@ def get_available_clusters(
return [] return []
def _fallback_indication_groups(pairs: "pd.DataFrame") -> "pd.DataFrame":
    """Return one FALLBACK row per UPID labelled "<Directory> (no GP dx)"."""
    fallback_df = pairs[['UPID', 'Directory']].drop_duplicates(subset=['UPID']).copy()
    fallback_df['Indication_Group'] = fallback_df['Directory'] + " (no GP dx)"
    fallback_df['Source'] = "FALLBACK"
    return fallback_df[['UPID', 'Indication_Group', 'Source']]


def batch_lookup_indication_groups(
    df: "pd.DataFrame",
    connector: Optional[SnowflakeConnector] = None,
    db_manager: Optional[DatabaseManager] = None,
    batch_size: int = 500,
) -> "pd.DataFrame":
    """
    Batch lookup GP diagnosis-based indication groups for a DataFrame of patients.

    This is the batch counterpart of get_directorate_from_diagnosis(): GP
    records are queried for many patients per Snowflake round trip instead of
    one query per patient.

    Strategy:
    1. Collect the unique (UPID, Drug Name, PseudoNHSNoLinked, Directory) rows
    2. For each unique drug, get its SNOMED codes from local SQLite
    3. Query GP records in batches of ``batch_size`` patients
    4. Return one row per UPID mapping it to an Indication_Group

    A UPID that appears with several drugs or directories takes its result
    from the first row seen. A failed batch is logged and its patients end up
    as FALLBACK rows.

    Args:
        df: DataFrame with columns: UPID, Drug Name, Directory, PseudoNHSNoLinked
        connector: Optional SnowflakeConnector (defaults to singleton)
        db_manager: Optional DatabaseManager (defaults to default_db_manager)
        batch_size: Number of patients per Snowflake query batch

    Returns:
        DataFrame with columns: UPID, Indication_Group, Source
        - Indication_Group: Search_Term (if matched) or "Directory (no GP dx)" (if not)
        - Source: "DIAGNOSIS" or "FALLBACK"
    """
    import pandas as pd

    if db_manager is None:
        db_manager = default_db_manager

    logger.info(f"Starting batch indication lookup for {len(df)} records...")

    # PseudoNHSNoLinked is what matches PatientPseudonym in GP records.
    # PersonKey is LocalPatientID, which is provider-specific and does NOT match.
    if 'PseudoNHSNoLinked' not in df.columns:
        logger.error("DataFrame missing 'PseudoNHSNoLinked' column - cannot lookup GP records")
        return _fallback_indication_groups(df)

    # Step 1: one lookup per patient-drug pair
    unique_pairs = df[['UPID', 'Drug Name', 'PseudoNHSNoLinked', 'Directory']].drop_duplicates()
    logger.info(f"Found {len(unique_pairs)} unique patient-drug combinations")

    # Step 2: snomed_code -> [(drug, search_term), ...] across every drug
    unique_drugs = unique_pairs['Drug Name'].unique()
    logger.info(f"Building SNOMED lookup for {len(unique_drugs)} unique drugs...")

    snomed_to_drug_searchterm: dict[str, list[tuple[str, str]]] = {}
    for drug_name in unique_drugs:
        for mapping in get_drug_snomed_codes(drug_name, db_manager):
            snomed_to_drug_searchterm.setdefault(mapping.snomed_code, []).append(
                (drug_name, mapping.search_term)
            )

    logger.info(f"Total SNOMED codes to check: {len(snomed_to_drug_searchterm)}")

    # Step 3: Check Snowflake availability
    if not SNOWFLAKE_AVAILABLE or not is_snowflake_configured():
        logger.warning("Snowflake not available - returning fallback for all patients")
        return _fallback_indication_groups(unique_pairs)

    if connector is None:
        connector = get_connector()

    # Step 4: Query GP records for all patients in batches.
    # Missing pseudonyms are dropped: they cannot match any GP record.
    patient_pseudonyms = (
        unique_pairs['PseudoNHSNoLinked'].dropna().drop_duplicates().tolist()
    )
    logger.info(f"Querying GP records for {len(patient_pseudonyms)} unique patients in batches of {batch_size}...")

    snomed_list = list(snomed_to_drug_searchterm)
    if not snomed_list:
        logger.warning("No SNOMED codes to check - returning fallback for all patients")
        return _fallback_indication_groups(unique_pairs)

    # SNOMED IN clause is identical for every batch, so build it once.
    snomed_placeholders = ", ".join(["%s"] * len(snomed_list))

    # PatientPseudonym -> most recent matching SNOMED code
    gp_matches: dict[str, str] = {}

    for batch_start in range(0, len(patient_pseudonyms), batch_size):
        batch_end = min(batch_start + batch_size, len(patient_pseudonyms))
        batch_pseudonyms = patient_pseudonyms[batch_start:batch_end]

        logger.info(f"Batch {batch_start//batch_size + 1}: patients {batch_start} to {batch_end}")

        patient_placeholders = ", ".join(["%s"] * len(batch_pseudonyms))

        # All matches come back newest-first per patient; the first row seen
        # for a patient is therefore their most recent match.
        query = f'''
            SELECT
                "PatientPseudonym",
                "SNOMEDCode",
                "EventDateTime"
            FROM DATA_HUB.PHM."PrimaryCareClinicalCoding"
            WHERE "PatientPseudonym" IN ({patient_placeholders})
              AND "SNOMEDCode" IN ({snomed_placeholders})
            ORDER BY "PatientPseudonym", "EventDateTime" DESC
        '''

        params = tuple(batch_pseudonyms) + tuple(snomed_list)

        try:
            for gp_row in connector.execute_dict(query, params):
                pseudonym = gp_row.get("PatientPseudonym")
                snomed_code = gp_row.get("SNOMEDCode")
                if pseudonym and snomed_code:
                    gp_matches.setdefault(pseudonym, snomed_code)
        except Exception as e:
            # Best effort: keep going so one bad batch does not lose the rest.
            logger.error(f"Error querying GP records for batch: {e}")

    logger.info(f"Found GP matches for {len(gp_matches)} patients")

    # Step 5: Decide one (Indication_Group, Source) per UPID.
    upid_to_match: dict[str, tuple[str, str]] = {}

    for upid, drug_name, pseudonym, directory in unique_pairs.itertuples(index=False, name=None):
        if upid in upid_to_match:
            continue  # first row seen for a UPID wins

        candidates = snomed_to_drug_searchterm.get(gp_matches.get(pseudonym))
        if not candidates:
            upid_to_match[upid] = (directory + " (no GP dx)", "FALLBACK")
            continue

        # Prefer the search term recorded for this drug; a code may map to
        # several drugs with different search terms.
        # (Loop names deliberately avoid ``pd`` so the pandas alias survives.)
        search_term = next(
            (term for drug, term in candidates if drug.upper() == drug_name.upper()),
            None,
        )
        if search_term is None:
            search_term = candidates[0][1]
        upid_to_match[upid] = (search_term, "DIAGNOSIS")

    # Explicit columns keep the frame well-formed even when there are no rows.
    result_df = pd.DataFrame(
        [(upid, group, source) for upid, (group, source) in upid_to_match.items()],
        columns=['UPID', 'Indication_Group', 'Source'],
    )

    # Log statistics (percentages are reported as 0.0 for an empty result)
    total = len(result_df)
    diagnosis_count = int((result_df['Source'] == "DIAGNOSIS").sum())
    fallback_count = int((result_df['Source'] == "FALLBACK").sum())
    diagnosis_pct = 100 * diagnosis_count / total if total else 0.0
    fallback_pct = 100 * fallback_count / total if total else 0.0

    logger.info("Indication lookup complete:")
    logger.info(f"  Total unique patients: {total}")
    logger.info(f"  DIAGNOSIS matches: {diagnosis_count} ({diagnosis_pct:.1f}%)")
    logger.info(f"  FALLBACK (no GP match): {fallback_count} ({fallback_pct:.1f}%)")

    return result_df
# === Drug-to-indication mapping from DimSearchTerm.csv === # === Drug-to-indication mapping from DimSearchTerm.csv ===
@@ -1713,10 +1194,7 @@ __all__ = [
"ClusterSnomedCodes", "ClusterSnomedCodes",
"IndicationValidationResult", "IndicationValidationResult",
"DrugIndicationMatchRate", "DrugIndicationMatchRate",
"DrugSnomedMapping", # Cluster-based lookup functions
"DirectSnomedMatchResult",
"DirectorateAssignment",
# Cluster-based lookup functions (existing)
"get_drug_clusters", "get_drug_clusters",
"get_drug_cluster_ids", "get_drug_cluster_ids",
"get_cluster_snomed_codes", "get_cluster_snomed_codes",
@@ -1725,20 +1203,13 @@ __all__ = [
"get_indication_match_rate", "get_indication_match_rate",
"batch_validate_indications", "batch_validate_indications",
"get_available_clusters", "get_available_clusters",
# Direct SNOMED lookup functions (new)
"get_drug_snomed_codes",
"patient_has_indication_direct",
# Diagnosis-based directorate assignment
"get_directorate_from_diagnosis",
# Batch lookup for indication groups
"batch_lookup_indication_groups",
# Drug-indication mapping from DimSearchTerm.csv # Drug-indication mapping from DimSearchTerm.csv
"SEARCH_TERM_MERGE_MAP", "SEARCH_TERM_MERGE_MAP",
"load_drug_indication_mapping", "load_drug_indication_mapping",
"get_search_terms_for_drug", "get_search_terms_for_drug",
# Drug-aware indication assignment # Drug-aware indication assignment
"assign_drug_indications", "assign_drug_indications",
# Snowflake-direct indication lookup (new approach) # Snowflake-direct indication lookup
"get_patient_indication_groups", "get_patient_indication_groups",
"CLUSTER_MAPPING_SQL", "CLUSTER_MAPPING_SQL",
] ]
-401
View File
@@ -1,401 +0,0 @@
"""
Load enriched SNOMED mapping data into SQLite database.
This module loads the drug_snomed_mapping_enriched.csv file into the
ref_drug_snomed_mapping table for direct GP record matching.
Source file: data/drug_snomed_mapping_enriched.csv (163K rows)
Target table: ref_drug_snomed_mapping
Usage:
python -m data_processing.load_snomed_mapping
Columns mapped:
Drug -> drug_name
Indication -> indication
TA_ID -> ta_id
Search_Term -> search_term
SNOMEDCode -> snomed_code (cleaned: removes trailing .0)
SNOMEDDescription -> snomed_description
CleanedDrugName -> cleaned_drug_name
PrimaryDirectorate -> primary_directorate
AllDirectorates -> all_directorates
"""
from pathlib import Path
from typing import Optional
from core.logging_config import get_logger
from data_processing.database import DatabaseManager
from data_processing.reference_data import MigrationResult, _read_csv_with_fallback_encoding
logger = get_logger(__name__)
DEFAULT_CSV_PATH = Path("./data/drug_snomed_mapping_enriched.csv")
def clean_snomed_code(snomed_code: str) -> str:
    """
    Clean a SNOMED code by undoing decimal / scientific-notation export damage.

    The enriched CSV has SNOMED codes that may be in decimal notation
    (e.g., "156370009.0") or scientific notation
    (e.g., "1.0629311000119108e+16") due to pandas/Excel export. These are
    converted back to plain digit strings.

    Scientific notation is parsed with ``decimal.Decimal`` rather than
    ``float``: a float cannot represent every integer above 2**53, so long
    SNOMED codes would otherwise be silently altered.

    Args:
        snomed_code: Raw SNOMED code from CSV.

    Returns:
        Cleaned SNOMED code as string (e.g., "156370009" or
        "10629311000119108"). Empty input returns "". Text that contains an
        "e" but is not a usable number is returned stripped and otherwise
        unchanged.
    """
    from decimal import Decimal, InvalidOperation

    # SNOMED CT identifiers are at most 18 digits; anything wider is not a
    # code, and refusing it also avoids expanding absurd exponents.
    max_digits = 18

    if not snomed_code:
        return ""

    code = snomed_code.strip()

    # Handle scientific notation (e.g., "1.0629311000119108e+16")
    if 'e' in code.lower():
        try:
            value = Decimal(code)
        except InvalidOperation:
            # Not numeric at all (just happens to contain an "e")
            return code

        if not value.is_finite() or value.adjusted() >= max_digits:
            return code

        if value == value.to_integral_value():
            return str(int(value))

        if value.as_tuple().exponent < -max_digits:
            return code

        # Has a fractional part - return it in plain (non-exponent) form
        return format(value, "f").rstrip("0")

    # Remove trailing .0 if present (for non-scientific notation)
    if code.endswith(".0"):
        code = code[:-2]

    return code
def migrate_drug_snomed_mapping(
    db_manager: Optional[DatabaseManager] = None,
    csv_path: Optional[Path] = None
) -> MigrationResult:
    """
    Load drug SNOMED mappings from a CSV file into ref_drug_snomed_mapping.

    Expected column order (a header row starting with "Drug" is skipped):
        Drug, Indication, TA_ID, Search_Term, SNOMEDCode, SNOMEDDescription,
        CleanedDrugName, PrimaryDirectorate, AllDirectorates

    The first five columns must be present; Drug, Indication, Search_Term and
    SNOMEDCode must also be non-empty. Columns six to nine are optional:
    a missing CleanedDrugName falls back to the upper-cased drug name, the
    others to an empty string. Rows are written with INSERT OR IGNORE, and a
    row the database ignores is counted as skipped.

    Args:
        db_manager: DatabaseManager instance. Uses default if not provided.
        csv_path: Path to the CSV file. Defaults to DEFAULT_CSV_PATH.

    Returns:
        MigrationResult with read/inserted/skipped counts. When the file is
        missing or any exception is raised, success is False, the inserted
        and skipped counts are reported as 0 and error_message is set.
    """
    if db_manager is None:
        db_manager = DatabaseManager()
    if csv_path is None:
        csv_path = DEFAULT_CSV_PATH

    table_name = "ref_drug_snomed_mapping"

    def outcome(read, inserted, skipped, **extra):
        # Every exit path reports the same table and source file.
        return MigrationResult(
            table_name=table_name,
            source_file=str(csv_path),
            rows_read=read,
            rows_inserted=inserted,
            rows_skipped=skipped,
            **extra,
        )

    logger.info(f"Migrating drug SNOMED mappings from {csv_path} to {table_name}")

    if not csv_path.exists():
        error_msg = f"Source file not found: {csv_path}"
        logger.error(error_msg)
        return outcome(0, 0, 0, success=False, error_message=error_msg)

    insert_sql = """
        INSERT OR IGNORE INTO ref_drug_snomed_mapping
        (drug_name, indication, ta_id, search_term, snomed_code, snomed_description,
        cleaned_drug_name, primary_directorate, all_directorates)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
        """

    seen = 0
    added = 0
    dropped = 0

    try:
        with db_manager.get_transaction() as conn:
            for index, row in enumerate(_read_csv_with_fallback_encoding(csv_path)):
                is_header = (
                    index == 0
                    and len(row) >= 5
                    and row[0].strip().lower() == "drug"
                )
                if is_header:
                    logger.debug("Skipping header row")
                    continue

                seen += 1

                # Need at least: Drug, Indication, TA_ID, Search_Term, SNOMEDCode
                if len(row) < 5:
                    logger.warning(f"Skipping malformed row {seen}: {row}")
                    dropped += 1
                    continue

                cells = [cell.strip() for cell in row[:9]]
                drug_name, indication, ta_id, search_term, snomed_code_raw = cells[:5]
                snomed_description = cells[5] if len(cells) > 5 else ""
                cleaned_drug_name = cells[6] if len(cells) > 6 else drug_name.upper()
                primary_directorate = cells[7] if len(cells) > 7 else ""
                all_directorates = cells[8] if len(cells) > 8 else ""

                if not (drug_name and indication and search_term and snomed_code_raw):
                    logger.warning(f"Skipping row {seen} with empty required fields")
                    dropped += 1
                    continue

                snomed_code = clean_snomed_code(snomed_code_raw)
                if not snomed_code:
                    logger.warning(f"Skipping row {seen} with invalid SNOMED code: {snomed_code_raw}")
                    dropped += 1
                    continue

                cursor = conn.execute(
                    insert_sql,
                    (
                        drug_name,
                        indication,
                        ta_id,
                        search_term,
                        snomed_code,
                        snomed_description,
                        cleaned_drug_name,
                        primary_directorate,
                        all_directorates,
                    ),
                )

                # rowcount stays at 0 when INSERT OR IGNORE discards the row
                if cursor.rowcount > 0:
                    added += 1
                else:
                    dropped += 1

                # Progress heartbeat every 10000 rows
                if seen % 10000 == 0:
                    logger.info(f"Processed {seen} rows, inserted {added}")

        logger.info(
            f"Drug SNOMED mapping migration complete: {seen} rows read, "
            f"{added} inserted, {dropped} skipped"
        )
        return outcome(seen, added, dropped, success=True)

    except Exception as e:
        error_msg = f"Migration failed: {e}"
        logger.error(error_msg)
        return outcome(seen, 0, 0, success=False, error_message=error_msg)
def get_drug_snomed_mapping_counts(db_manager: Optional[DatabaseManager] = None) -> dict:
    """
    Summarise the contents of the ref_drug_snomed_mapping table.

    Args:
        db_manager: DatabaseManager instance. Uses default if not provided.

    Returns:
        Dictionary with:
        - total_mappings: Total rows in table
        - unique_drugs: Count of distinct drug names
        - unique_search_terms: Count of distinct search terms
        - unique_snomed_codes: Count of distinct SNOMED codes
        - unique_indications: Count of distinct indications
    """
    manager = DatabaseManager() if db_manager is None else db_manager

    # Result key -> column whose distinct values are counted
    distinct_columns = (
        ("unique_drugs", "drug_name"),
        ("unique_search_terms", "search_term"),
        ("unique_snomed_codes", "snomed_code"),
        ("unique_indications", "indication"),
    )

    with manager.get_connection() as conn:
        counts = {
            "total_mappings": conn.execute(
                "SELECT COUNT(*) FROM ref_drug_snomed_mapping"
            ).fetchone()[0]
        }
        # Column names come from the fixed tuple above, never from callers.
        for key, column in distinct_columns:
            counts[key] = conn.execute(
                f"SELECT COUNT(DISTINCT {column}) FROM ref_drug_snomed_mapping"
            ).fetchone()[0]

    return counts
def verify_drug_snomed_mapping_migration(
    db_manager: Optional[DatabaseManager] = None,
    csv_path: Optional[Path] = None
) -> tuple[bool, str]:
    """
    Sanity-check the contents of ref_drug_snomed_mapping after a migration.

    Checks, in order (the first failure is returned):
    - at least 100,000 rows are present (163K+ expected)
    - at least 100 distinct search terms are present (~187 expected)
    - the sample drug ABATACEPT has at least one row
    - no snomed_code still ends in ".0"

    Args:
        db_manager: DatabaseManager instance. Uses default if not provided.
        csv_path: Path to the CSV file. Defaults to
            data/drug_snomed_mapping_enriched.csv. The checks below only
            query the database; this value is not read by them.

    Returns:
        Tuple of (success: bool, message: str)
    """
    manager = DatabaseManager() if db_manager is None else db_manager
    if csv_path is None:
        csv_path = DEFAULT_CSV_PATH

    stats = get_drug_snomed_mapping_counts(manager)
    total = stats["total_mappings"]
    search_terms = stats["unique_search_terms"]

    # Volume thresholds
    if total < 100000:
        return False, f"Too few rows: expected 163K+, got {total}"
    if search_terms < 100:
        return False, f"Too few search terms: expected ~187, got {search_terms}"

    # Spot checks against the stored rows
    with manager.get_connection() as conn:
        abatacept_rows = conn.execute(
            "SELECT COUNT(*) FROM ref_drug_snomed_mapping WHERE drug_name = 'ABATACEPT'"
        ).fetchone()[0]
        if abatacept_rows == 0:
            return False, "Sample drug ABATACEPT not found in table"

        uncleaned = conn.execute(
            "SELECT COUNT(*) FROM ref_drug_snomed_mapping WHERE snomed_code LIKE '%.0'"
        ).fetchone()[0]
        if uncleaned > 0:
            return False, f"Found {uncleaned} SNOMED codes with uncleaned .0 suffix"

    summary = (
        f"Verified {total:,} mappings: "
        f"{stats['unique_drugs']} drugs, "
        f"{search_terms} search terms, "
        f"{stats['unique_snomed_codes']:,} SNOMED codes"
    )
    return True, summary
def main():
    """
    Command-line entry point: load the SNOMED mapping CSV, or verify it.

    Options:
        --csv PATH      CSV file to load (default: DEFAULT_CSV_PATH)
        --verify-only   only run verification against existing data
        -v, --verbose   log at DEBUG instead of INFO

    Returns:
        0 when the requested operation and its verification succeed,
        1 otherwise.
    """
    import argparse
    import logging

    parser = argparse.ArgumentParser(
        description="Load drug SNOMED mapping data into SQLite database"
    )
    parser.add_argument(
        "--csv",
        type=Path,
        default=DEFAULT_CSV_PATH,
        help=f"Path to CSV file (default: {DEFAULT_CSV_PATH})",
    )
    parser.add_argument(
        "--verify-only",
        action="store_true",
        help="Only verify existing data, don't migrate",
    )
    parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Enable verbose logging",
    )
    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

    # Verification-only mode never touches the CSV contents
    if args.verify_only:
        print("Verifying existing data...")
        ok, detail = verify_drug_snomed_mapping_migration(csv_path=args.csv)
        if ok:
            print(f"[OK] Verification passed: {detail}")
            return 0
        print(f"[FAILED] Verification failed: {detail}")
        return 1

    print(f"Loading SNOMED mapping from {args.csv}...")
    result = migrate_drug_snomed_mapping(csv_path=args.csv)
    if not result.success:
        print(f"[FAILED] {result}")
        return 1

    print(f"[OK] {result}")

    stats = get_drug_snomed_mapping_counts()
    print("\nTable statistics:")
    print(f" Total mappings: {stats['total_mappings']:,}")
    print(f" Unique drugs: {stats['unique_drugs']}")
    print(f" Unique search terms: {stats['unique_search_terms']}")
    print(f" Unique SNOMED codes: {stats['unique_snomed_codes']:,}")
    print(f" Unique indications: {stats['unique_indications']}")

    # A load that fails verification is reported as a non-zero exit
    ok, detail = verify_drug_snomed_mapping_migration(csv_path=args.csv)
    if not ok:
        print(f"\n[WARNING] Verification: {detail}")
        return 1

    print(f"\n[OK] Verification: {detail}")
    return 0
# Script entry point: propagate main()'s return value as the process exit
# status. SystemExit is raised directly because the bare exit() helper is
# installed by the site module and is not available when site is disabled.
if __name__ == "__main__":
    raise SystemExit(main())
+2 -155
View File
@@ -11,7 +11,6 @@ The DataLoader ABC defines the contract for all loader implementations.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import date
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@@ -29,7 +28,7 @@ class LoadResult:
Attributes: Attributes:
df: The loaded DataFrame with processed patient intervention data df: The loaded DataFrame with processed patient intervention data
source: Description of the data source (e.g., "csv:/path/to/file.csv", "sqlite:fact_interventions") source: Description of the data source (e.g., "file:/path/to/file.csv")
row_count: Number of rows loaded row_count: Number of rows loaded
columns: List of column names in the DataFrame columns: List of column names in the DataFrame
load_time_seconds: Time taken to load the data load_time_seconds: Time taken to load the data
@@ -224,150 +223,6 @@ class FileDataLoader(DataLoader):
) )
class SQLiteDataLoader(DataLoader):
    """DataLoader that reads rows from the SQLite fact_interventions table.

    Reads data that has already been written to SQLite rather than
    re-processing CSV files on every load. The table must already exist and
    contain rows; validate_source() reports when it does not.

    Args:
        db_path: Path to the SQLite database (uses default if None)
        date_range: Optional (start_date, end_date) pair; rows are kept when
            start_date <= intervention_date < end_date
        trusts: Optional list of trust names to filter
        drugs: Optional list of drug names to filter
        directories: Optional list of directories to filter
    """

    def __init__(
        self,
        db_path: Optional[Path | str] = None,
        date_range: Optional[tuple[date, date]] = None,
        trusts: Optional[list[str]] = None,
        drugs: Optional[list[str]] = None,
        directories: Optional[list[str]] = None,
    ):
        from data_processing.database import default_db_config

        chosen_path = db_path if db_path else default_db_config.db_path
        self.db_path = Path(chosen_path)
        self.date_range = date_range
        self.trusts = trusts
        self.drugs = drugs
        self.directories = directories

    def validate_source(self) -> tuple[bool, str]:
        """Report whether the database file and a populated fact_interventions table exist."""
        if not self.db_path.exists():
            return False, f"Database not found: {self.db_path}"

        from data_processing.database import DatabaseManager, DatabaseConfig

        manager = DatabaseManager(DatabaseConfig(db_path=self.db_path))

        if not manager.table_exists("fact_interventions"):
            return False, "fact_interventions table not found in database"

        available = manager.get_table_count("fact_interventions")
        if available == 0:
            return False, "fact_interventions table is empty"

        return True, f"OK ({available:,} rows available)"

    @property
    def source_description(self) -> str:
        return f"sqlite:{self.db_path}"

    def load(self) -> LoadResult:
        """Query fact_interventions and return the rows as a LoadResult.

        SQLite column names are aliased to the DataFrame column names used
        elsewhere, and the optional date range, trust, drug and directory
        filters are applied in the query itself.

        Raises:
            FileNotFoundError: if validate_source() rejects the database.
            ValueError: if the loaded frame fails validate_dataframe().
        """
        import time
        from data_processing.database import DatabaseManager, DatabaseConfig

        started = time.time()

        source_ok, reason = self.validate_source()
        if not source_ok:
            raise FileNotFoundError(reason)

        logger.info(f"Loading data from SQLite: {self.db_path}")

        query = """
            SELECT
            upid AS "UPID",
            provider_code AS "Provider Code",
            person_key AS "PersonKey",
            drug_name_std AS "Drug Name",
            intervention_date AS "Intervention Date",
            price_actual AS "Price Actual",
            org_name AS "OrganisationName",
            directory AS "Directory",
            treatment_function_code AS "Treatment Function Code",
            additional_detail_1 AS "Additional Detail 1",
            additional_detail_2 AS "Additional Detail 2",
            additional_detail_3 AS "Additional Detail 3",
            additional_detail_4 AS "Additional Detail 4",
            additional_detail_5 AS "Additional Detail 5"
            FROM fact_interventions
            WHERE 1=1
            """
        params = []

        if self.date_range:
            start, end = self.date_range
            query += " AND intervention_date >= ? AND intervention_date < ?"
            params.extend([str(start), str(end)])

        # Each populated list filter becomes one parameterised IN (...) clause
        list_filters = (
            ("org_name", self.trusts),
            ("drug_name_std", self.drugs),
            ("directory", self.directories),
        )
        for column, wanted in list_filters:
            if wanted:
                placeholders = ','.join('?' * len(wanted))
                query += f" AND {column} IN ({placeholders})"
                params.extend(wanted)

        manager = DatabaseManager(DatabaseConfig(db_path=self.db_path))
        with manager.get_connection() as conn:
            df = pd.read_sql_query(query, conn, params=params)

        # The query returns the date column unparsed; convert it to datetimes
        df['Intervention Date'] = pd.to_datetime(df['Intervention Date'])

        logger.info(f"Loaded {len(df)} rows from SQLite")

        frame_ok, missing = self.validate_dataframe(df)
        if not frame_ok:
            raise ValueError(f"SQLite data missing required columns: {missing}")

        elapsed = time.time() - started
        logger.info(f"SQLite data loading complete. {len(df)} rows in {elapsed:.2f}s")

        return LoadResult(
            df=df,
            source=self.source_description,
            row_count=len(df),
            load_time_seconds=elapsed,
        )
def get_loader( def get_loader(
source: str | Path, source: str | Path,
paths: Optional[PathConfig] = None, paths: Optional[PathConfig] = None,
@@ -376,7 +231,7 @@ def get_loader(
"""Factory function to create the appropriate DataLoader. """Factory function to create the appropriate DataLoader.
Args: Args:
source: Either a file path (CSV/Parquet) or "sqlite" for database source: File path (CSV/Parquet)
paths: PathConfig for reference data (used by FileDataLoader) paths: PathConfig for reference data (used by FileDataLoader)
**kwargs: Additional arguments passed to the loader constructor **kwargs: Additional arguments passed to the loader constructor
@@ -386,14 +241,6 @@ def get_loader(
Examples: Examples:
>>> loader = get_loader("data/activity.csv") >>> loader = get_loader("data/activity.csv")
>>> loader = get_loader("data/activity.parquet") >>> loader = get_loader("data/activity.parquet")
>>> loader = get_loader("sqlite")
>>> loader = get_loader("sqlite", date_range=(date(2024, 1, 1), date(2024, 12, 31)))
""" """
source_str = str(source).lower()
if source_str == "sqlite":
return SQLiteDataLoader(**kwargs)
# Assume it's a file path
path = Path(source) path = Path(source)
return FileDataLoader(file_path=path, paths=paths) return FileDataLoader(file_path=path, paths=paths)
+11 -155
View File
@@ -35,6 +35,7 @@ from data_processing.schema import (
verify_all_tables_exist, verify_all_tables_exist,
get_all_table_counts, get_all_table_counts,
migrate_pathway_nodes_chart_type, migrate_pathway_nodes_chart_type,
migrate_refresh_log_source_row_count,
) )
from data_processing.reference_data import ( from data_processing.reference_data import (
MigrationResult, MigrationResult,
@@ -49,12 +50,6 @@ from data_processing.reference_data import (
verify_drug_directory_map_migration, verify_drug_directory_map_migration,
verify_drug_indication_clusters_migration, verify_drug_indication_clusters_migration,
) )
from data_processing.patient_data import (
load_patient_data,
refresh_patient_treatment_summary,
get_patient_data_stats,
verify_mv_consistency,
)
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -67,9 +62,8 @@ def initialize_database(
""" """
Initialize the database with all required tables. Initialize the database with all required tables.
Creates all tables defined in the schema (reference tables, fact tables, Creates all tables defined in the schema (reference tables and pathway
materialized views, and file tracking tables). Uses IF NOT EXISTS so tables). Uses IF NOT EXISTS so safe to run multiple times.
safe to run multiple times.
Args: Args:
db_manager: DatabaseManager instance. Uses default if not provided. db_manager: DatabaseManager instance. Uses default if not provided.
@@ -122,6 +116,14 @@ def initialize_database(
else: else:
logger.error(f"pathway_nodes migration failed: {msg}") logger.error(f"pathway_nodes migration failed: {msg}")
return False return False
# Add source_row_count column to pathway_refresh_log if it doesn't exist
success, msg = migrate_refresh_log_source_row_count(conn)
if success:
logger.info(f"pathway_refresh_log migration: {msg}")
else:
logger.error(f"pathway_refresh_log migration failed: {msg}")
return False
except Exception as e: except Exception as e:
logger.error(f"Migration failed: {e}") logger.error(f"Migration failed: {e}")
return False return False
@@ -274,107 +276,6 @@ def create_progress_reporter(description: str = "Loading", width: int = 40):
return report_progress return report_progress
def load_patient_data_cli(
    file_path: Path,
    db_manager: Optional[DatabaseManager] = None,
    paths: Optional[PathConfig] = None,
    force: bool = False,
    refresh_mv: bool = True
) -> bool:
    """
    Load a patient data file and print progress and a summary to stdout.

    Runs load_patient_data() in batches of 5000 with a console progress
    reporter. After a fresh successful load it optionally refreshes the
    materialized view and verifies it, then prints table statistics.

    Args:
        file_path: Path to CSV or Parquet file.
        db_manager: DatabaseManager instance. Uses default if not provided.
        paths: PathConfig for reference data. Uses default if not provided.
        force: If True, re-process even if file hash matches.
        refresh_mv: If True, refresh the materialized view after loading.

    Returns:
        False when the file is missing or the load fails; otherwise the
        success flag of the load result.
    """
    manager = DatabaseManager() if db_manager is None else db_manager
    path_config = default_paths if paths is None else paths

    print("\n=== Loading Patient Data ===\n")
    print(f"File: {file_path}")

    if not file_path.exists():
        print(f"ERROR: File not found: {file_path}")
        return False

    size_mb = file_path.stat().st_size / (1024 * 1024)
    print(f"Size: {size_mb:.1f} MB")
    print()

    result = load_patient_data(
        file_path=file_path,
        db_manager=manager,
        paths=path_config,
        batch_size=5000,
        force=force,
        progress_callback=create_progress_reporter("Loading rows", width=40)
    )

    print()
    if result.was_already_processed:
        print("File already processed (same hash). Skipping.")
        print("Use --force to re-process.")
    elif result.success:
        print(f"Loaded {result.rows_inserted:,} rows in {result.load_time_seconds:.1f}s")
        if result.rows_skipped > 0:
            print(f"Skipped {result.rows_skipped:,} rows (missing UPID or date)")
    else:
        print(f"FAILED: {result.error_message}")
        return False

    # Only a fresh, successful load changes the data the view is built from
    fresh_load = result.success and not result.was_already_processed
    if refresh_mv and fresh_load:
        print()
        print("Refreshing materialized view...")
        mv_result = refresh_patient_treatment_summary(
            db_manager=manager,
            progress_callback=create_progress_reporter("Processing patients", width=40)
        )

        if mv_result.success:
            print(f"MV refreshed: {mv_result.patients_processed:,} patients in {mv_result.refresh_time_seconds:.1f}s")

            consistent, msg = verify_mv_consistency(manager)
            if consistent:
                print("MV verification: OK")
            else:
                print(f"MV verification: FAILED - {msg}")
        else:
            print(f"MV refresh FAILED: {mv_result.error_message}")

    print()
    print("=== Patient Data Summary ===")
    stats = get_patient_data_stats(manager)
    print(f" Total rows: {stats['total_rows']:,}")
    print(f" Unique patients: {stats['unique_patients']:,}")
    print(f" Unique drugs: {stats['unique_drugs']:,}")
    print(f" Unique organizations: {stats['unique_organizations']:,}")
    earliest = stats['date_range'][0]
    latest = stats['date_range'][1]
    if earliest and latest:
        print(f" Date range: {earliest} to {latest}")
    print()

    return result.success
def get_database_status(db_manager: Optional[DatabaseManager] = None) -> dict: def get_database_status(db_manager: Optional[DatabaseManager] = None) -> dict:
""" """
Get the current status of the database. Get the current status of the database.
@@ -452,8 +353,6 @@ Examples:
python -m data_processing.migrate --drop-existing # Reset database python -m data_processing.migrate --drop-existing # Reset database
python -m data_processing.migrate --reference-data # Migrate reference data python -m data_processing.migrate --reference-data # Migrate reference data
python -m data_processing.migrate --reference-data --verify # With verification python -m data_processing.migrate --reference-data --verify # With verification
python -m data_processing.migrate --load-patient-data data.parquet # Load patient data
python -m data_processing.migrate --load-patient-data data.csv --force # Force reload
python -m data_processing.migrate --db-path ./data/test.db # Custom path python -m data_processing.migrate --db-path ./data/test.db # Custom path
""" """
) )
@@ -493,23 +392,6 @@ Examples:
action="store_true", action="store_true",
help="Enable verbose logging" help="Enable verbose logging"
) )
parser.add_argument(
"--load-patient-data",
type=Path,
metavar="FILE",
help="Load patient data from CSV or Parquet file with progress reporting"
)
parser.add_argument(
"--force",
action="store_true",
help="Force re-processing even if file hash matches (use with --load-patient-data)"
)
parser.add_argument(
"--no-refresh-mv",
action="store_true",
help="Skip materialized view refresh after loading (use with --load-patient-data)"
)
args = parser.parse_args() args = parser.parse_args()
# Set up logging # Set up logging
@@ -562,32 +444,6 @@ Examples:
print("Reference data migration completed with errors. Check logs for details.") print("Reference data migration completed with errors. Check logs for details.")
return 1 return 1
# Handle --load-patient-data (load patient data from CSV/Parquet)
if args.load_patient_data:
# Ensure database exists with tables first
if not db_manager.exists:
print("Database does not exist. Initializing schema first...")
success = initialize_database(db_manager=db_manager)
if not success:
print("\nDatabase initialization failed. Check logs for details.")
return 1
# Load patient data with progress reporting
success = load_patient_data_cli(
file_path=args.load_patient_data,
db_manager=db_manager,
paths=default_paths,
force=args.force,
refresh_mv=not args.no_refresh_mv
)
if success:
print("Patient data load completed successfully.")
return 0
else:
print("Patient data load failed. Check logs for details.")
return 1
# Run schema migration (default behavior) # Run schema migration (default behavior)
success = initialize_database( success = initialize_database(
db_manager=db_manager, db_manager=db_manager,
-890
View File
@@ -1,890 +0,0 @@
"""
Patient data migration functions for NHS High-Cost Drug Patient Pathway Analysis Tool.
Provides functions to load patient intervention data from CSV/Parquet files
into the SQLite fact_interventions table. Supports:
- Batch processing for large files
- File hash tracking for incremental updates
- Progress reporting during loading
"""
import hashlib
import os
import sqlite3
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Callable, Optional
import pandas as pd
from core import PathConfig, default_paths
from core.logging_config import get_logger
from data_processing.database import DatabaseManager
logger = get_logger(__name__)
@dataclass
class PatientDataLoadResult:
    """Outcome of one patient data load attempt, with a one-line summary via str()."""

    file_path: str
    file_hash: str
    rows_read: int
    rows_inserted: int
    rows_skipped: int
    success: bool
    error_message: Optional[str] = None
    load_time_seconds: float = 0.0
    was_already_processed: bool = False

    def __str__(self) -> str:
        # "Already processed" takes precedence over the success flag.
        if self.was_already_processed:
            return f"{self.file_path}: Already processed (same hash)"
        if not self.success:
            return f"{self.file_path}: FAILED - {self.error_message}"
        loaded = f"Loaded {self.rows_inserted:,} rows"
        elapsed = f"in {self.load_time_seconds:.1f}s"
        return f"{self.file_path}: {loaded} {elapsed}"
def calculate_file_hash(file_path: Path) -> str:
    """
    Return the SHA256 digest of a file's contents.

    The file is read in 8192-byte chunks so it is never held in memory whole.

    Args:
        file_path: Path to the file.

    Returns:
        Hex string of SHA256 hash.
    """
    digest = hashlib.sha256()
    with open(file_path, "rb") as handle:
        while block := handle.read(8192):
            digest.update(block)
    return digest.hexdigest()
def check_file_processed(
    conn: sqlite3.Connection,
    file_path: str,
    file_hash: str
) -> tuple[bool, Optional[str]]:
    """
    Look up a file in processed_files and compare its recorded hash.

    Result columns are read by name, so the connection must return rows that
    support name-based indexing.

    Args:
        conn: Database connection.
        file_path: Full path to the file.
        file_hash: SHA256 hash of the file.

    Returns:
        Tuple of (is_processed, old_hash).
        - (True, old_hash): recorded with status "success" and the same hash.
        - (False, old_hash): recorded, but the hash differs or the last run
          did not finish with status "success".
        - (False, None): no record for this path.
    """
    record = conn.execute(
        "SELECT file_hash, status FROM processed_files WHERE file_path = ?",
        (file_path,)
    ).fetchone()

    if record is None:
        return False, None

    recorded_hash = record["file_hash"]
    # Both conditions must hold before a re-run can be skipped
    unchanged = record["status"] == "success" and recorded_hash == file_hash
    return unchanged, recorded_hash
def record_file_processing_start(
    conn: sqlite3.Connection,
    file_path: str,
    file_hash: str,
    file_size: int,
    file_modified: datetime
) -> None:
    """
    Mark a file as 'processing' in processed_files.

    Inserts a new row for an unseen path. For a path already recorded, the
    hash, size, modification time and last_processed_at are overwritten, the
    status is reset to 'processing' and error_message is cleared;
    first_processed_at is left untouched. Nothing is committed here.

    Args:
        conn: Database connection.
        file_path: Full path to the file.
        file_hash: SHA256 hash of the file.
        file_size: File size in bytes.
        file_modified: File modification timestamp.
    """
    started_at = datetime.now().isoformat()
    values = (
        file_path,
        Path(file_path).name,
        file_hash,
        file_size,
        file_modified.isoformat(),
        started_at,  # first_processed_at (used on insert only)
        started_at,  # last_processed_at
    )

    upsert_sql = """
        INSERT INTO processed_files (
        file_path, file_name, file_hash, file_size_bytes,
        file_modified_at, status, first_processed_at, last_processed_at
        ) VALUES (?, ?, ?, ?, ?, 'processing', ?, ?)
        ON CONFLICT(file_path) DO UPDATE SET
        file_hash = excluded.file_hash,
        file_size_bytes = excluded.file_size_bytes,
        file_modified_at = excluded.file_modified_at,
        status = 'processing',
        last_processed_at = excluded.last_processed_at,
        error_message = NULL
        """
    conn.execute(upsert_sql, values)
def record_file_processing_complete(
    conn: sqlite3.Connection,
    file_path: str,
    row_count: int,
    duration_seconds: float,
    success: bool,
    error_message: Optional[str] = None
) -> None:
    """
    Store the final outcome of a file's processing run in processed_files.

    Updates the existing row for file_path; the status becomes "success" or
    "error" and last_processed_at is set to the current time. Nothing is
    committed here.

    Args:
        conn: Database connection.
        file_path: Full path to the file.
        row_count: Number of rows processed.
        duration_seconds: Time taken to process.
        success: Whether processing was successful.
        error_message: Error message if failed.
    """
    if success:
        final_status = "success"
    else:
        final_status = "error"

    finished_at = datetime.now().isoformat()
    update_sql = """
        UPDATE processed_files
        SET status = ?,
        row_count = ?,
        processing_duration_seconds = ?,
        error_message = ?,
        last_processed_at = ?
        WHERE file_path = ?
        """
    conn.execute(
        update_sql,
        (final_status, row_count, duration_seconds, error_message, finished_at, file_path),
    )
def load_dataframe_to_sqlite(
    df: pd.DataFrame,
    conn: sqlite3.Connection,
    source_file: str,
    batch_size: int = 5000,
    progress_callback: Optional[Callable[[int, int], None]] = None
) -> int:
    """
    Insert the rows of a processed DataFrame into fact_interventions.

    Rows without a UPID or an Intervention Date are skipped. When the
    standardized "Drug Name" is missing, the raw drug name (upper-cased and
    stripped) is stored in its place, or "UNKNOWN" when that is missing too.
    A missing "Price Actual" is stored as 0. Values that have a strftime
    method are written as "YYYY-MM-DD" text. Nothing is committed here.

    Args:
        df: Processed DataFrame with required columns (from FileDataLoader).
        conn: Database connection.
        source_file: Source file path for tracking.
        batch_size: Number of rows to insert per batch.
        progress_callback: Optional callback(rows_inserted, total_rows),
            called once after each batch.

    Returns:
        Number of rows inserted.
    """
    target_columns = [
        "upid", "provider_code", "person_key",
        "drug_name_raw", "drug_name_std",
        "intervention_date", "price_actual",
        "org_name", "directory",
        "treatment_function_code",
        "additional_detail_1", "additional_detail_2", "additional_detail_3",
        "additional_detail_4", "additional_detail_5",
        "source_file"
    ]
    column_list = ",".join(target_columns)
    placeholders = ",".join(["?"] * len(target_columns))
    insert_sql = f"""
        INSERT INTO fact_interventions ({column_list})
        VALUES ({placeholders})
        """

    has_raw_names = "Drug Name Raw" in df.columns

    def cell(record, col_name):
        # Absent column or missing value -> NULL; date-like values -> ISO date text
        if col_name not in df.columns:
            return None
        val = record[col_name]
        if pd.isna(val):
            return None
        elif hasattr(val, "strftime"):
            return val.strftime("%Y-%m-%d")
        return val

    inserted = 0
    skipped = 0
    total_rows = len(df)

    for offset in range(0, total_rows, batch_size):
        chunk = df.iloc[offset:min(offset + batch_size, total_rows)]

        pending = []
        for _, record in chunk.iterrows():
            # Both fields are mandatory for a stored intervention
            if pd.isna(record.get("UPID")) or pd.isna(record.get("Intervention Date")):
                skipped += 1
                continue

            raw_name = record.get("Drug Name Raw") if has_raw_names else None
            std_name = record.get("Drug Name")
            raw_usable = raw_name is not None and not pd.isna(raw_name)

            # No standardized name: fall back to the raw name, else "UNKNOWN"
            if pd.isna(std_name):
                std_name = str(raw_name).upper().strip() if raw_usable else "UNKNOWN"

            if raw_usable:
                raw_name = str(raw_name).strip()

            pending.append((
                cell(record, "UPID"),
                cell(record, "Provider Code"),
                cell(record, "PersonKey"),
                raw_name,
                std_name,
                cell(record, "Intervention Date"),
                cell(record, "Price Actual") or 0,
                cell(record, "OrganisationName"),
                cell(record, "Directory"),
                cell(record, "Treatment Function Code"),
                cell(record, "Additional Detail 1"),
                cell(record, "Additional Detail 2"),
                cell(record, "Additional Detail 3"),
                cell(record, "Additional Detail 4"),
                cell(record, "Additional Detail 5"),
                source_file
            ))

        conn.executemany(insert_sql, pending)
        inserted += len(pending)

        if progress_callback:
            progress_callback(inserted, total_rows)

    if skipped > 0:
        logger.info(f"Skipped {skipped:,} rows with missing UPID or Intervention Date")

    return inserted
def delete_file_data(conn: sqlite3.Connection, source_file: str) -> int:
    """
    Remove every fact_interventions row that came from one source file.

    Used when re-processing a changed file. Nothing is committed here.

    Args:
        conn: Database connection.
        source_file: Source file path.

    Returns:
        Number of rows deleted.
    """
    delete_sql = "DELETE FROM fact_interventions WHERE source_file = ?"
    return conn.execute(delete_sql, (source_file,)).rowcount
def load_patient_data(
    file_path: Path | str,
    db_manager: Optional[DatabaseManager] = None,
    paths: Optional[PathConfig] = None,
    batch_size: int = 5000,
    force: bool = False,
    progress_callback: Optional[Callable[[int, int], None]] = None
) -> PatientDataLoadResult:
    """
    Load patient data from CSV/Parquet file into fact_interventions table.

    This is the main entry point for loading patient data. It:
    1. Calculates file hash to detect changes
    2. Checks if file was already processed (skip if unchanged)
    3. Loads and transforms data using FileDataLoader
    4. Inserts data into SQLite in batches
    5. Records processing status in processed_files table

    Errors raised while loading, transforming or inserting are caught and
    reported through the returned result (success=False). Errors raised while
    hashing or stat-ing the file happen before that guard and propagate to the
    caller.

    Args:
        file_path: Path to CSV or Parquet file.
        db_manager: DatabaseManager instance. Uses default if not provided.
        paths: PathConfig for reference data. Uses default if not provided.
        batch_size: Number of rows to insert per batch (default: 5000).
        force: If True, re-process even if file hash matches.
        progress_callback: Optional callback(rows_inserted, total_rows) for progress.

    Returns:
        PatientDataLoadResult with loading statistics.
    """
    if db_manager is None:
        db_manager = DatabaseManager()
    if paths is None:
        paths = default_paths

    file_path = Path(file_path)
    file_path_str = str(file_path.absolute())
    logger.info(f"Starting patient data load from {file_path}")
    start_time = time.time()

    # Check file exists
    if not file_path.exists():
        error_msg = f"File not found: {file_path}"
        logger.error(error_msg)
        return PatientDataLoadResult(
            file_path=file_path_str,
            file_hash="",
            rows_read=0,
            rows_inserted=0,
            rows_skipped=0,
            success=False,
            error_message=error_msg
        )

    # Calculate file hash
    logger.info("Calculating file hash...")
    file_hash = calculate_file_hash(file_path)
    # One stat() call so size and mtime describe the same snapshot of the file.
    file_stat = file_path.stat()
    file_size = file_stat.st_size
    file_modified = datetime.fromtimestamp(file_stat.st_mtime)
    logger.info(f"File hash: {file_hash[:16]}... Size: {file_size:,} bytes")

    # Check if already processed
    if not force:
        with db_manager.get_connection() as conn:
            is_processed, old_hash = check_file_processed(conn, file_path_str, file_hash)
        if is_processed:
            logger.info("File already processed with same hash, skipping")
            return PatientDataLoadResult(
                file_path=file_path_str,
                file_hash=file_hash,
                rows_read=0,
                rows_inserted=0,
                rows_skipped=0,
                success=True,
                was_already_processed=True
            )
        elif old_hash is not None:
            logger.info(f"File hash changed, will re-process (old: {old_hash[:16]}...)")

    try:
        # Use FileDataLoader to load and transform data
        from data_processing.loader import FileDataLoader

        loader = FileDataLoader(file_path, paths)
        logger.info("Loading and transforming data...")
        result = loader.load()
        df = result.df
        rows_read = result.row_count
        logger.info(f"Loaded {rows_read:,} rows, starting SQLite insert...")

        # Load into SQLite
        with db_manager.get_transaction() as conn:
            # Record that we're starting
            record_file_processing_start(conn, file_path_str, file_hash, file_size, file_modified)

            # Delete any existing data from this file (for re-processing)
            deleted = delete_file_data(conn, file_path_str)
            if deleted > 0:
                logger.info(f"Deleted {deleted:,} existing rows from previous load")

            # Insert new data
            rows_inserted = load_dataframe_to_sqlite(
                df, conn, file_path_str, batch_size, progress_callback
            )

            # Record success
            load_time = time.time() - start_time
            record_file_processing_complete(
                conn, file_path_str, rows_inserted, load_time, True
            )

        logger.info(f"Successfully loaded {rows_inserted:,} rows in {load_time:.1f}s")
        return PatientDataLoadResult(
            file_path=file_path_str,
            file_hash=file_hash,
            rows_read=rows_read,
            rows_inserted=rows_inserted,
            rows_skipped=rows_read - rows_inserted,
            success=True,
            load_time_seconds=load_time
        )
    except Exception as e:
        load_time = time.time() - start_time
        error_msg = str(e)
        # logger.exception keeps the traceback, which the bare message loses.
        logger.exception(f"Failed to load patient data: {error_msg}")

        # Record failure. This stays best-effort (a failure here must not mask
        # the original error), but it is logged rather than silently dropped.
        try:
            with db_manager.get_connection() as conn:
                record_file_processing_complete(
                    conn, file_path_str, 0, load_time, False, error_msg
                )
        except Exception:
            logger.warning(
                "Could not record failed load status for %s",
                file_path_str,
                exc_info=True,
            )

        # file_hash is always bound here: it is computed before the try block.
        return PatientDataLoadResult(
            file_path=file_path_str,
            file_hash=file_hash,
            rows_read=0,
            rows_inserted=0,
            rows_skipped=0,
            success=False,
            error_message=error_msg,
            load_time_seconds=load_time
        )
def get_patient_data_stats(db_manager: Optional[DatabaseManager] = None) -> dict:
    """
    Summarise the contents of fact_interventions and processed_files.

    Args:
        db_manager: DatabaseManager instance. Uses default if not provided.

    Returns:
        Dictionary with the keys total_rows, unique_patients, unique_drugs,
        unique_organizations, date_range (min, max intervention date),
        processed_files and processed_rows (both counted over files whose
        status is 'success').
    """
    if db_manager is None:
        db_manager = DatabaseManager()

    # (stats key, query) pairs; every query yields a single scalar.
    scalar_queries = (
        ("total_rows", "SELECT COUNT(*) FROM fact_interventions"),
        ("unique_patients", "SELECT COUNT(DISTINCT upid) FROM fact_interventions"),
        ("unique_drugs", "SELECT COUNT(DISTINCT drug_name_std) FROM fact_interventions"),
        ("unique_organizations", "SELECT COUNT(DISTINCT org_name) FROM fact_interventions"),
    )

    stats = {}
    with db_manager.get_connection() as conn:
        for key, sql in scalar_queries:
            stats[key] = conn.execute(sql).fetchone()[0]

        # Date range
        date_row = conn.execute("""
            SELECT MIN(intervention_date), MAX(intervention_date)
            FROM fact_interventions
        """).fetchone()
        if date_row:
            stats["date_range"] = (date_row[0], date_row[1])
        else:
            stats["date_range"] = (None, None)

        # Processed files; SUM() is NULL when no file matches, hence the fallback.
        files_row = conn.execute("""
            SELECT COUNT(*), SUM(row_count)
            FROM processed_files WHERE status = 'success'
        """).fetchone()
        if files_row:
            stats["processed_files"] = files_row[0]
            stats["processed_rows"] = files_row[1] if files_row[1] else 0
        else:
            stats["processed_files"] = 0
            stats["processed_rows"] = 0

    return stats
def list_processed_files(db_manager: Optional[DatabaseManager] = None) -> list[dict]:
    """
    Return one dictionary per row of processed_files, most recent first.

    Rows are read by column name, so the connection must yield rows that
    support name-based indexing.

    Args:
        db_manager: DatabaseManager instance. Uses default if not provided.

    Returns:
        List of dictionaries with file processing information, ordered by
        last_processed_at descending.
    """
    if db_manager is None:
        db_manager = DatabaseManager()

    # Columns copied from each row into the result dictionaries.
    columns = (
        "file_path",
        "file_name",
        "file_hash",
        "file_size_bytes",
        "row_count",
        "status",
        "error_message",
        "first_processed_at",
        "last_processed_at",
        "processing_duration_seconds",
    )

    with db_manager.get_connection() as conn:
        rows = conn.execute("""
            SELECT file_path, file_name, file_hash, file_size_bytes,
            row_count, status, error_message,
            first_processed_at, last_processed_at, processing_duration_seconds
            FROM processed_files
            ORDER BY last_processed_at DESC
        """).fetchall()
        files = [{column: row[column] for column in columns} for row in rows]

    return files
# =============================================================================
# Materialized View Refresh Functions
# =============================================================================
@dataclass
class MVRefreshResult:
    """Outcome of one refresh of the patient treatment summary materialized view."""

    # Number of distinct patients the refresh worked from.
    patients_processed: int
    # Number of rows present in the MV after the refresh.
    rows_inserted: int
    # Wall-clock duration of the refresh, in seconds.
    refresh_time_seconds: float
    # True when the refresh finished without raising.
    success: bool
    # Text of the exception when success is False; otherwise None.
    error_message: Optional[str] = None

    def __str__(self) -> str:
        """Return a one-line, human-readable summary of the refresh."""
        if not self.success:
            return f"MV refresh FAILED: {self.error_message}"
        patients = f"{self.patients_processed:,}"
        elapsed = f"{self.refresh_time_seconds:.1f}"
        return f"Refreshed MV: {patients} patients in {elapsed}s"
def refresh_patient_treatment_summary(
    db_manager: Optional[DatabaseManager] = None,
    progress_callback: Optional[Callable[[int, int], None]] = None
) -> MVRefreshResult:
    """
    Refresh the mv_patient_treatment_summary materialized view.

    This computes per-patient aggregations from fact_interventions:
    - First/last seen dates
    - Total cost, average cost per intervention
    - Intervention count, unique drug count
    - Drug sequence (chronological, pipe-separated)
    - Drug counts, costs, and date ranges (as JSON)

    The MV is fully rebuilt (delete everything, then re-insert) inside a single
    transaction. The old rows are removed even when fact_interventions is
    empty, so the MV can never keep summaries for patients that no longer
    exist in the fact table.

    NOTE(review): drug_sequence relies on GROUP_CONCAT honouring the ORDER BY
    of its subquery, which SQLite does not formally guarantee — confirm before
    depending on the order. The JSON columns are assembled by string
    concatenation, so a drug name containing a double quote or backslash would
    yield invalid JSON.

    Args:
        db_manager: DatabaseManager instance. Uses default if not provided.
        progress_callback: Optional callback(patients_done, total_patients).
            It is invoked once, after the rebuild has finished.

    Returns:
        MVRefreshResult with refresh statistics. Any exception is caught and
        reported as success=False with the exception text in error_message.
    """
    if db_manager is None:
        db_manager = DatabaseManager()

    logger.info("Starting materialized view refresh...")
    start_time = time.time()

    try:
        with db_manager.get_transaction() as conn:
            # Step 1: Get total patient count for progress reporting
            cursor = conn.execute("SELECT COUNT(DISTINCT upid) FROM fact_interventions")
            total_patients = cursor.fetchone()[0]
            logger.info(f"Processing {total_patients:,} unique patients")

            # Step 2: Clear existing MV data. This must happen before the
            # empty-table early return below, otherwise stale rows would
            # survive a refresh against an emptied fact table.
            conn.execute("DELETE FROM mv_patient_treatment_summary")
            logger.info("Cleared existing MV data")

            if total_patients == 0:
                logger.warning("No patient data in fact_interventions, MV will be empty")
                return MVRefreshResult(
                    patients_processed=0,
                    rows_inserted=0,
                    refresh_time_seconds=time.time() - start_time,
                    success=True
                )

            # Step 3: Compute aggregations using SQL CTEs
            # This is more efficient than processing row-by-row in Python
            refresh_sql = """
                WITH patient_aggs AS (
                    -- Basic aggregations per patient
                    SELECT
                        upid,
                        MIN(org_name) as org_name,
                        MIN(directory) as directory,
                        MIN(intervention_date) as first_seen_date,
                        MAX(intervention_date) as last_seen_date,
                        JULIANDAY(MAX(intervention_date)) - JULIANDAY(MIN(intervention_date)) as days_treated,
                        SUM(price_actual) as total_cost,
                        AVG(price_actual) as avg_cost_per_intervention,
                        COUNT(*) as intervention_count,
                        COUNT(DISTINCT drug_name_std) as unique_drug_count,
                        COUNT(*) as source_row_count
                    FROM fact_interventions
                    GROUP BY upid
                ),
                drug_sequences AS (
                    -- Drug sequence per patient (chronological order, pipe-separated)
                    SELECT
                        upid,
                        GROUP_CONCAT(drug_name_std, '|') as drug_sequence
                    FROM (
                        SELECT DISTINCT
                            upid,
                            drug_name_std,
                            MIN(intervention_date) as first_date
                        FROM fact_interventions
                        GROUP BY upid, drug_name_std
                        ORDER BY upid, first_date
                    )
                    GROUP BY upid
                ),
                drug_counts AS (
                    -- JSON object of drug counts per patient
                    SELECT
                        upid,
                        '{' || GROUP_CONCAT('"' || drug_name_std || '": ' || cnt, ', ') || '}' as drug_counts_json
                    FROM (
                        SELECT
                            upid,
                            drug_name_std,
                            COUNT(*) as cnt
                        FROM fact_interventions
                        GROUP BY upid, drug_name_std
                    )
                    GROUP BY upid
                ),
                drug_costs AS (
                    -- JSON object of drug costs per patient
                    SELECT
                        upid,
                        '{' || GROUP_CONCAT('"' || drug_name_std || '": ' || ROUND(total_cost, 2), ', ') || '}' as drug_costs_json
                    FROM (
                        SELECT
                            upid,
                            drug_name_std,
                            SUM(price_actual) as total_cost
                        FROM fact_interventions
                        GROUP BY upid, drug_name_std
                    )
                    GROUP BY upid
                ),
                drug_dates AS (
                    -- JSON object of drug date ranges per patient
                    SELECT
                        upid,
                        '{' || GROUP_CONCAT('"' || drug_name_std || '": {"first": "' || first_date || '", "last": "' || last_date || '"}', ', ') || '}' as drug_date_ranges_json
                    FROM (
                        SELECT
                            upid,
                            drug_name_std,
                            MIN(intervention_date) as first_date,
                            MAX(intervention_date) as last_date
                        FROM fact_interventions
                        GROUP BY upid, drug_name_std
                    )
                    GROUP BY upid
                )
                INSERT INTO mv_patient_treatment_summary (
                    upid, org_name, directory,
                    first_seen_date, last_seen_date, days_treated,
                    total_cost, avg_cost_per_intervention,
                    intervention_count, unique_drug_count,
                    drug_sequence, drug_counts_json, drug_costs_json, drug_date_ranges_json,
                    source_row_count, computed_at
                )
                SELECT
                    pa.upid,
                    pa.org_name,
                    pa.directory,
                    pa.first_seen_date,
                    pa.last_seen_date,
                    CAST(pa.days_treated AS INTEGER),
                    pa.total_cost,
                    pa.avg_cost_per_intervention,
                    pa.intervention_count,
                    pa.unique_drug_count,
                    ds.drug_sequence,
                    dc.drug_counts_json,
                    dco.drug_costs_json,
                    dd.drug_date_ranges_json,
                    pa.source_row_count,
                    CURRENT_TIMESTAMP
                FROM patient_aggs pa
                LEFT JOIN drug_sequences ds ON pa.upid = ds.upid
                LEFT JOIN drug_counts dc ON pa.upid = dc.upid
                LEFT JOIN drug_costs dco ON pa.upid = dco.upid
                LEFT JOIN drug_dates dd ON pa.upid = dd.upid
            """
            logger.info("Executing MV refresh query...")
            conn.execute(refresh_sql)

            # Get actual rows inserted
            cursor = conn.execute("SELECT COUNT(*) FROM mv_patient_treatment_summary")
            rows_inserted = cursor.fetchone()[0]

        refresh_time = time.time() - start_time
        logger.info(f"MV refresh complete: {rows_inserted:,} rows in {refresh_time:.1f}s")

        # Report progress if callback provided
        if progress_callback:
            progress_callback(rows_inserted, total_patients)

        return MVRefreshResult(
            patients_processed=total_patients,
            rows_inserted=rows_inserted,
            refresh_time_seconds=refresh_time,
            success=True
        )
    except Exception as e:
        refresh_time = time.time() - start_time
        error_msg = str(e)
        logger.error(f"MV refresh failed: {error_msg}")
        return MVRefreshResult(
            patients_processed=0,
            rows_inserted=0,
            refresh_time_seconds=refresh_time,
            success=False,
            error_message=error_msg
        )
def get_patient_summary_stats(db_manager: Optional[DatabaseManager] = None) -> dict:
    """
    Summarise the contents of mv_patient_treatment_summary.

    Args:
        db_manager: DatabaseManager instance. Uses default if not provided.

    Returns:
        Dictionary with MV statistics. When the MV is empty only
        total_patients (0) is present; otherwise the cost, intervention, drug
        and duration aggregates, date_range, last_refresh and the distinct
        directory / organization counts are included as well.
    """
    if db_manager is None:
        db_manager = DatabaseManager()

    # Numeric aggregates, in the order the aggregate query selects them.
    numeric_keys = (
        "total_cost",
        "avg_cost_per_patient",
        "total_interventions",
        "avg_interventions_per_patient",
        "avg_drugs_per_patient",
        "avg_days_treated",
    )
    distinct_queries = (
        ("unique_directories", "SELECT COUNT(DISTINCT directory) FROM mv_patient_treatment_summary"),
        ("unique_organizations", "SELECT COUNT(DISTINCT org_name) FROM mv_patient_treatment_summary"),
    )

    stats = {}
    with db_manager.get_connection() as conn:
        patient_total = conn.execute("SELECT COUNT(*) FROM mv_patient_treatment_summary").fetchone()[0]
        stats["total_patients"] = patient_total
        if patient_total == 0:
            # Nothing to aggregate; skip the remaining queries.
            return stats

        agg = conn.execute("""
            SELECT
            SUM(total_cost) as total_cost_all,
            AVG(total_cost) as avg_cost_per_patient,
            SUM(intervention_count) as total_interventions,
            AVG(intervention_count) as avg_interventions_per_patient,
            AVG(unique_drug_count) as avg_drugs_per_patient,
            AVG(days_treated) as avg_days_treated,
            MIN(first_seen_date) as earliest_date,
            MAX(last_seen_date) as latest_date,
            MAX(computed_at) as last_refresh
            FROM mv_patient_treatment_summary
        """).fetchone()

        # NULL (or zero) aggregates are reported as 0.
        for position, key in enumerate(numeric_keys):
            stats[key] = agg[position] or 0
        stats["date_range"] = (agg[6], agg[7])
        stats["last_refresh"] = agg[8]

        for key, sql in distinct_queries:
            stats[key] = conn.execute(sql).fetchone()[0]

    return stats
def verify_mv_consistency(db_manager: Optional[DatabaseManager] = None) -> tuple[bool, str]:
    """
    Cross-check mv_patient_treatment_summary against fact_interventions.

    Three totals are compared: distinct patients, intervention count, and the
    overall cost (allowing an absolute difference of up to 0.01 for floating
    point noise).

    Args:
        db_manager: DatabaseManager instance. Uses default if not provided.

    Returns:
        Tuple of (is_consistent, message). On mismatch the message lists every
        failed check, separated by "; ".
    """
    if db_manager is None:
        db_manager = DatabaseManager()

    with db_manager.get_connection() as conn:
        # Totals straight from the fact table; NULL aggregates count as 0.
        fact_totals = conn.execute("""
            SELECT
            COUNT(DISTINCT upid) as patients,
            SUM(price_actual) as total_cost,
            COUNT(*) as interventions
            FROM fact_interventions
        """).fetchone()
        fact_patients, fact_cost, fact_interventions = (value or 0 for value in fact_totals)

        # The same totals as recorded in the MV.
        mv_totals = conn.execute("""
            SELECT
            COUNT(*) as patients,
            SUM(total_cost) as total_cost,
            SUM(intervention_count) as interventions
            FROM mv_patient_treatment_summary
        """).fetchone()
        mv_patients, mv_cost, mv_interventions = (value or 0 for value in mv_totals)

    problems = []
    if mv_patients != fact_patients:
        problems.append(f"Patient count mismatch: fact={fact_patients:,}, mv={mv_patients:,}")
    if fact_interventions != mv_interventions:
        problems.append(f"Intervention count mismatch: fact={fact_interventions:,}, mv={mv_interventions:,}")

    # Allow small floating point differences in cost
    cost_diff = abs(fact_cost - mv_cost)
    if cost_diff > 0.01:
        problems.append(f"Cost mismatch: fact={fact_cost:,.2f}, mv={mv_cost:,.2f}, diff={cost_diff:.2f}")

    if not problems:
        return True, f"MV consistent: {mv_patients:,} patients, {mv_interventions:,} interventions, £{mv_cost:,.2f} total"
    return False, "; ".join(problems)
+27 -438
View File
@@ -115,43 +115,6 @@ CREATE INDEX IF NOT EXISTS idx_ref_drug_indication_clusters_cluster ON ref_drug_
CREATE INDEX IF NOT EXISTS idx_ref_drug_indication_clusters_indication ON ref_drug_indication_clusters(indication); CREATE INDEX IF NOT EXISTS idx_ref_drug_indication_clusters_indication ON ref_drug_indication_clusters(indication);
""" """
# DDL script for the ref_drug_snomed_mapping reference table and its five
# lookup indexes. It is run by create_drug_snomed_mapping_table() via
# executescript(). Every statement uses IF NOT EXISTS, so the script can be
# re-run against a database that already contains the table.
# NOTE(review): idx_ref_drug_snomed_mapping_cleaned indexes only the leading
# column of idx_ref_drug_snomed_mapping_drug_snomed, so it may be redundant —
# confirm against real query plans before removing it.
REF_DRUG_SNOMED_MAPPING_SCHEMA = """
-- Direct SNOMED code mapping from drug to indication to GP diagnosis codes
-- Source: data/drug_snomed_mapping_enriched.csv (163K rows)
-- Used for direct GP record matching to assign diagnosis-based directorates
-- and to support indication-based pathway hierarchy (Trust → Search_Term → Drug → Pathway)
CREATE TABLE IF NOT EXISTS ref_drug_snomed_mapping (
id INTEGER PRIMARY KEY AUTOINCREMENT,
drug_name TEXT NOT NULL, -- Original drug name from mapping
indication TEXT NOT NULL, -- Specific indication (603 unique values)
ta_id TEXT, -- NICE TA reference (e.g., TA568)
search_term TEXT NOT NULL, -- Simplified grouping (187 unique values)
snomed_code TEXT NOT NULL, -- SNOMED CT code for GP record matching
snomed_description TEXT, -- SNOMED code description
cleaned_drug_name TEXT NOT NULL, -- Standardized drug name for matching
primary_directorate TEXT, -- Primary directorate for this indication
all_directorates TEXT, -- Pipe-separated list of valid directorates
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE(drug_name, indication, snomed_code)
);
-- Index for looking up SNOMED codes by drug name (most common access pattern)
CREATE INDEX IF NOT EXISTS idx_ref_drug_snomed_mapping_drug ON ref_drug_snomed_mapping(drug_name);
-- Index for looking up by cleaned drug name (standardized matching)
CREATE INDEX IF NOT EXISTS idx_ref_drug_snomed_mapping_cleaned ON ref_drug_snomed_mapping(cleaned_drug_name);
-- Index for looking up by SNOMED code (reverse lookup from GP record)
CREATE INDEX IF NOT EXISTS idx_ref_drug_snomed_mapping_snomed ON ref_drug_snomed_mapping(snomed_code);
-- Index for grouping by search_term (indication-based hierarchy)
CREATE INDEX IF NOT EXISTS idx_ref_drug_snomed_mapping_search_term ON ref_drug_snomed_mapping(search_term);
-- Composite index for drug + snomed code (common lookup pattern)
CREATE INDEX IF NOT EXISTS idx_ref_drug_snomed_mapping_drug_snomed
ON ref_drug_snomed_mapping(cleaned_drug_name, snomed_code);
"""
# ============================================================================= # =============================================================================
# Pathway Data Architecture Schemas # Pathway Data Architecture Schemas
@@ -278,6 +241,7 @@ CREATE TABLE IF NOT EXISTS pathway_refresh_log (
snowflake_query_date_from TEXT, -- Start date of Snowflake query snowflake_query_date_from TEXT, -- Start date of Snowflake query
snowflake_query_date_to TEXT, -- End date of Snowflake query snowflake_query_date_to TEXT, -- End date of Snowflake query
processing_duration_seconds REAL, -- How long the refresh took processing_duration_seconds REAL, -- How long the refresh took
source_row_count INTEGER, -- Number of Snowflake rows fetched
created_at TEXT DEFAULT CURRENT_TIMESTAMP created_at TEXT DEFAULT CURRENT_TIMESTAMP
); );
@@ -301,208 +265,6 @@ PATHWAY_TABLES_SCHEMA = f"""
""" """
# =============================================================================
# Fact Table Schemas
# =============================================================================
# DDL script for the fact_interventions table (one row per patient
# intervention event) and its seven indexes. Every statement uses
# IF NOT EXISTS, so the script can be re-run against an existing database.
# NOTE(review): idx_fact_interventions_upid and idx_fact_interventions_org
# index only the leading column of idx_fact_interventions_upid_date and
# idx_fact_interventions_composite respectively, so they may be redundant —
# confirm against real query plans before removing them.
FACT_INTERVENTIONS_SCHEMA = """
-- Patient intervention records (fact table)
-- Source: HCD activity data (CSV/Parquet files or Snowflake)
-- This is the main fact table storing all patient intervention events
CREATE TABLE IF NOT EXISTS fact_interventions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- Patient identification
upid TEXT NOT NULL, -- Unique Patient ID (Provider Code[:3] + PersonKey)
provider_code TEXT NOT NULL, -- Original provider code (3-5 chars)
person_key TEXT NOT NULL, -- Patient key from source system
-- Intervention details
drug_name_raw TEXT, -- Original drug name from source
drug_name_std TEXT NOT NULL, -- Standardized drug name (via ref_drug_names)
intervention_date DATE NOT NULL, -- Date of intervention
price_actual REAL NOT NULL DEFAULT 0, -- Cost of intervention in GBP
-- Organization and directory
org_name TEXT, -- Organization name (cleaned, no commas)
directory TEXT, -- Medical directory/specialty (may be "Undefined")
-- Source tracking
source_file TEXT, -- Original file this record came from
loaded_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-- Additional clinical fields (optional, used in directory fallback logic)
treatment_function_code INTEGER,
additional_detail_1 TEXT,
additional_detail_2 TEXT,
additional_detail_3 TEXT,
additional_detail_4 TEXT,
additional_detail_5 TEXT
);
-- Primary indexes for common filter patterns used in generate_graph()
-- UPID: Used for patient grouping, pathway analysis
CREATE INDEX IF NOT EXISTS idx_fact_interventions_upid ON fact_interventions(upid);
-- Drug name (standardized): Used for drug filtering
CREATE INDEX IF NOT EXISTS idx_fact_interventions_drug ON fact_interventions(drug_name_std);
-- Intervention date: Used for date range filtering (start_date, end_date, last_seen)
CREATE INDEX IF NOT EXISTS idx_fact_interventions_date ON fact_interventions(intervention_date);
-- Directory: Used for directory/specialty filtering
CREATE INDEX IF NOT EXISTS idx_fact_interventions_directory ON fact_interventions(directory);
-- Organization: Used for trust filtering (Provider Code maps to org_name)
CREATE INDEX IF NOT EXISTS idx_fact_interventions_org ON fact_interventions(org_name);
-- Composite index for common filter combination (trust + drug + directory)
CREATE INDEX IF NOT EXISTS idx_fact_interventions_composite
ON fact_interventions(org_name, drug_name_std, directory);
-- Composite index for date-based patient analysis
CREATE INDEX IF NOT EXISTS idx_fact_interventions_upid_date
ON fact_interventions(upid, intervention_date);
"""
# =============================================================================
# Materialized View Schemas (Cached Aggregations)
# =============================================================================
# DDL script for mv_patient_treatment_summary: one pre-aggregated row per
# patient, plus the indexes used to filter it. Every statement uses
# IF NOT EXISTS, so the script can be re-run against an existing database.
#
# No explicit index is created on upid: the UNIQUE constraint on that column
# already makes SQLite maintain an implicit unique index, so a second index on
# the same column would only add storage and write cost. Databases created
# with an older version of this script simply keep their extra index.
MV_PATIENT_TREATMENT_SUMMARY_SCHEMA = """
-- Materialized view of patient treatment summaries
-- Pre-computed aggregations per patient for faster pathway analysis
-- Refreshed when fact_interventions data changes
CREATE TABLE IF NOT EXISTS mv_patient_treatment_summary (
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- Patient identification
upid TEXT NOT NULL UNIQUE, -- Unique Patient ID
-- Organization and directory (for filtering)
org_name TEXT, -- Organization name (first org seen)
directory TEXT, -- Primary directory (first directory assigned)
-- Date range
first_seen_date DATE NOT NULL, -- First intervention date
last_seen_date DATE NOT NULL, -- Last intervention date
days_treated INTEGER NOT NULL DEFAULT 0, -- Duration: last_seen - first_seen
-- Cost aggregations
total_cost REAL NOT NULL DEFAULT 0, -- Sum of all intervention costs
avg_cost_per_intervention REAL, -- Average cost per intervention
-- Treatment summary
intervention_count INTEGER NOT NULL DEFAULT 0, -- Total number of interventions
unique_drug_count INTEGER NOT NULL DEFAULT 0, -- Number of distinct drugs
-- Drug sequence (pipe-separated standardized drug names in chronological order)
-- Example: "ADALIMUMAB|ETANERCEPT|INFLIXIMAB"
drug_sequence TEXT,
-- Drug frequency counts (JSON: {"ADALIMUMAB": 5, "ETANERCEPT": 3})
-- Stores count of each drug for this patient
drug_counts_json TEXT,
-- Drug cost totals (JSON: {"ADALIMUMAB": 15000.00, "ETANERCEPT": 8000.00})
-- Stores total cost per drug for this patient
drug_costs_json TEXT,
-- Per-drug date ranges (JSON: {"ADALIMUMAB": {"first": "2023-01-01", "last": "2023-06-15"}, ...})
-- Stores first/last date for each drug
drug_date_ranges_json TEXT,
-- Metadata
computed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
source_row_count INTEGER -- Number of fact_interventions rows used
);
-- Patient lookup by upid is served by the implicit index behind the UNIQUE
-- constraint on upid, so no separate upid index is created here.
-- Indexes for common filter patterns
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_org ON mv_patient_treatment_summary(org_name);
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_directory ON mv_patient_treatment_summary(directory);
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_first_seen ON mv_patient_treatment_summary(first_seen_date);
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_last_seen ON mv_patient_treatment_summary(last_seen_date);
-- Composite index for date range filtering (common in generate_graph)
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_date_range
ON mv_patient_treatment_summary(first_seen_date, last_seen_date);
-- Composite index for org + directory + dates (full filter pattern)
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_filter_composite
ON mv_patient_treatment_summary(org_name, directory, first_seen_date, last_seen_date);
-- Index for drug sequence pattern matching
CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_drug_seq ON mv_patient_treatment_summary(drug_sequence);
"""

# Wrapper script that groups every materialized-view schema (currently one).
MATERIALIZED_VIEWS_SCHEMA = f"""
-- Materialized Views Schema
-- Pre-computed aggregations for performance
{MV_PATIENT_TREATMENT_SUMMARY_SCHEMA}
"""
# =============================================================================
# File Tracking Schemas (Incremental Updates)
# =============================================================================
# DDL script for processed_files, which keeps one row per source file path so
# that changed files can be detected by comparing hashes. Every statement uses
# IF NOT EXISTS, so the script can be re-run against an existing database.
#
# No explicit index is created on file_path: the UNIQUE(file_path) constraint
# already makes SQLite maintain an implicit unique index on that column, so a
# second index would only add storage and write cost. Databases created with
# an older version of this script simply keep their extra index.
PROCESSED_FILES_SCHEMA = """
-- Tracks processed data files for incremental updates
-- Enables detecting changed files by comparing hashes
-- Stores processing status and statistics
CREATE TABLE IF NOT EXISTS processed_files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- File identification
file_path TEXT NOT NULL, -- Full path to the file
file_name TEXT NOT NULL, -- Just the filename (for display)
file_hash TEXT NOT NULL, -- SHA256 hash of file contents
-- File metadata
file_size_bytes INTEGER, -- Size of file in bytes
file_modified_at TIMESTAMP, -- File's last modification timestamp
-- Processing results
row_count INTEGER DEFAULT 0, -- Number of rows processed from this file
status TEXT NOT NULL DEFAULT 'pending', -- pending, processing, success, error
error_message TEXT, -- Error details if status='error'
-- Timestamps
first_processed_at TIMESTAMP, -- When first processed
last_processed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
processing_duration_seconds REAL, -- How long processing took
-- Uniqueness: only one record per file path
-- Hash changes indicate file content changed (needs reprocessing)
UNIQUE(file_path)
);
-- Lookup by file path is served by the implicit index behind UNIQUE(file_path),
-- so no separate file_path index is created here.
-- Index for finding files by status (e.g., find all pending or errored files)
CREATE INDEX IF NOT EXISTS idx_processed_files_status ON processed_files(status);
-- Index for finding files by hash (detect if same file appears at different paths)
CREATE INDEX IF NOT EXISTS idx_processed_files_hash ON processed_files(file_hash);
-- Index for finding recently processed files
CREATE INDEX IF NOT EXISTS idx_processed_files_last_processed ON processed_files(last_processed_at);
"""

# Wrapper script that groups every file-tracking schema (currently one).
FILE_TRACKING_SCHEMA = f"""
-- File Tracking Schema
-- Supports incremental data loading
{PROCESSED_FILES_SCHEMA}
"""
# ============================================================================= # =============================================================================
# Combined Schemas # Combined Schemas
# ============================================================================= # =============================================================================
@@ -520,29 +282,14 @@ REFERENCE_TABLES_SCHEMA = f"""
{REF_DRUG_DIRECTORY_MAP_SCHEMA} {REF_DRUG_DIRECTORY_MAP_SCHEMA}
{REF_DRUG_INDICATION_CLUSTERS_SCHEMA} {REF_DRUG_INDICATION_CLUSTERS_SCHEMA}
{REF_DRUG_SNOMED_MAPPING_SCHEMA}
"""
FACT_TABLES_SCHEMA = f"""
-- Fact Tables Schema
-- Contains patient intervention data
{FACT_INTERVENTIONS_SCHEMA}
""" """
ALL_TABLES_SCHEMA = f""" ALL_TABLES_SCHEMA = f"""
-- Complete Database Schema -- Complete Database Schema
-- Reference tables + Fact tables + Materialized views + File tracking + Pathway tables -- Reference tables + Pathway tables
{REFERENCE_TABLES_SCHEMA} {REFERENCE_TABLES_SCHEMA}
{FACT_TABLES_SCHEMA}
{MATERIALIZED_VIEWS_SCHEMA}
{FILE_TRACKING_SCHEMA}
{PATHWAY_TABLES_SCHEMA} {PATHWAY_TABLES_SCHEMA}
""" """
@@ -580,26 +327,10 @@ def drop_reference_tables(conn: sqlite3.Connection) -> None:
DROP TABLE IF EXISTS ref_directories; DROP TABLE IF EXISTS ref_directories;
DROP TABLE IF EXISTS ref_drug_directory_map; DROP TABLE IF EXISTS ref_drug_directory_map;
DROP TABLE IF EXISTS ref_drug_indication_clusters; DROP TABLE IF EXISTS ref_drug_indication_clusters;
DROP TABLE IF EXISTS ref_drug_snomed_mapping;
""") """)
logger.info("Reference tables dropped") logger.info("Reference tables dropped")
def create_drug_snomed_mapping_table(conn: sqlite3.Connection) -> None:
    """
    Create ref_drug_snomed_mapping and its indexes on the given connection.

    Runs REF_DRUG_SNOMED_MAPPING_SCHEMA with executescript(). Note that
    executescript() may commit a pending transaction on the connection before
    running the script (see the sqlite3 documentation).

    Args:
        conn: SQLite database connection.
    """
    table_name = "ref_drug_snomed_mapping"
    logger.info(f"Creating {table_name} table...")
    conn.executescript(REF_DRUG_SNOMED_MAPPING_SCHEMA)
    logger.info(f"{table_name} table created successfully")
def get_reference_table_counts(conn: sqlite3.Connection) -> dict[str, int]: def get_reference_table_counts(conn: sqlite3.Connection) -> dict[str, int]:
""" """
Get row counts for all reference tables. Get row counts for all reference tables.
@@ -616,7 +347,6 @@ def get_reference_table_counts(conn: sqlite3.Connection) -> dict[str, int]:
"ref_directories", "ref_directories",
"ref_drug_directory_map", "ref_drug_directory_map",
"ref_drug_indication_clusters", "ref_drug_indication_clusters",
"ref_drug_snomed_mapping",
] ]
counts = {} counts = {}
@@ -647,7 +377,6 @@ def verify_reference_tables_exist(conn: sqlite3.Connection) -> list[str]:
"ref_directories", "ref_directories",
"ref_drug_directory_map", "ref_drug_directory_map",
"ref_drug_indication_clusters", "ref_drug_indication_clusters",
"ref_drug_snomed_mapping",
] ]
missing = [] missing = []
@@ -662,164 +391,6 @@ def verify_reference_tables_exist(conn: sqlite3.Connection) -> list[str]:
return missing return missing
# =============================================================================
# Fact Table Helper Functions
# =============================================================================
def create_fact_tables(conn: sqlite3.Connection) -> None:
    """
    Create the fact tables and the materialized-view tables.

    Runs FACT_TABLES_SCHEMA followed by MATERIALIZED_VIEWS_SCHEMA, each with
    executescript(). Note that executescript() may commit a pending
    transaction on the connection before running a script (see the sqlite3
    documentation).

    Args:
        conn: SQLite database connection.
    """
    logger.info("Creating fact tables...")
    # Order matters only for readability of the logs; both scripts are run.
    for schema_script in (FACT_TABLES_SCHEMA, MATERIALIZED_VIEWS_SCHEMA):
        conn.executescript(schema_script)
    logger.info("Fact tables created successfully")
def drop_fact_tables(conn: sqlite3.Connection) -> None:
    """
    Drop fact_interventions and mv_patient_treatment_summary if they exist.

    Args:
        conn: SQLite database connection.

    Warning:
        This deletes all patient intervention data and its pre-computed
        summaries. Use with caution.
    """
    drop_script = """
        DROP TABLE IF EXISTS fact_interventions;
        DROP TABLE IF EXISTS mv_patient_treatment_summary;
    """
    logger.warning("Dropping fact tables...")
    conn.executescript(drop_script)
    logger.info("Fact tables dropped")
def get_fact_table_counts(conn: sqlite3.Connection) -> dict[str, int]:
    """
    Count the rows in fact_interventions and mv_patient_treatment_summary.

    Both tables must exist; a missing table makes the underlying query fail.

    Args:
        conn: SQLite database connection.

    Returns:
        Dictionary mapping table name to row count.
    """
    def row_count(table_name: str) -> int:
        # Table names come from the fixed tuple below, never from user input.
        row = conn.execute(f"SELECT COUNT(*) FROM {table_name}").fetchone()
        return row[0] if row else 0

    fact_tables = ("fact_interventions", "mv_patient_treatment_summary")
    return {table_name: row_count(table_name) for table_name in fact_tables}
def verify_fact_tables_exist(conn: sqlite3.Connection) -> list[str]:
    """
    Report which fact tables (materialized views included) are absent.

    Args:
        conn: SQLite database connection.

    Returns:
        Names of the required tables not found in sqlite_master, in the
        order they are checked. An empty list means every table exists.
    """
    lookup_sql = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
    return [
        name
        for name in ("fact_interventions", "mv_patient_treatment_summary")
        if conn.execute(lookup_sql, (name,)).fetchone() is None
    ]
# =============================================================================
# File Tracking Helper Functions
# =============================================================================
def create_file_tracking_tables(conn: sqlite3.Connection) -> None:
    """
    Build the file tracking tables by running FILE_TRACKING_SCHEMA.

    Args:
        conn: SQLite database connection.
    """
    start_msg = "Creating file tracking tables..."
    done_msg = "File tracking tables created successfully"
    logger.info(start_msg)
    conn.executescript(FILE_TRACKING_SCHEMA)
    logger.info(done_msg)
def drop_file_tracking_tables(conn: sqlite3.Connection) -> None:
    """
    Remove the file tracking table (processed_files) if it is present.

    Args:
        conn: SQLite database connection.

    Warning:
        All file tracking history is deleted.
    """
    drop_script = """
        DROP TABLE IF EXISTS processed_files;
    """
    logger.warning("Dropping file tracking tables...")
    conn.executescript(drop_script)
    logger.info("File tracking tables dropped")
def get_file_tracking_counts(conn: sqlite3.Connection) -> dict[str, int]:
    """
    Count the rows in each file tracking table.

    Args:
        conn: SQLite database connection.

    Returns:
        Mapping of table name to its row count.
    """
    row_counts: dict[str, int] = {}
    for name in ("processed_files",):
        row = conn.execute(f"SELECT COUNT(*) FROM {name}").fetchone()
        row_counts[name] = 0 if not row else row[0]
    return row_counts
def verify_file_tracking_tables_exist(conn: sqlite3.Connection) -> list[str]:
    """
    Report which file tracking tables are absent.

    Args:
        conn: SQLite database connection.

    Returns:
        Names of the required tables not found in sqlite_master.
        An empty list means every table exists.
    """
    lookup_sql = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
    return [
        name
        for name in ("processed_files",)
        if conn.execute(lookup_sql, (name,)).fetchone() is None
    ]
# ============================================================================= # =============================================================================
# Pathway Table Helper Functions # Pathway Table Helper Functions
# ============================================================================= # =============================================================================
@@ -1050,13 +621,37 @@ def migrate_pathway_nodes_chart_type(conn: sqlite3.Connection) -> tuple[bool, st
return False, f"Migration failed: {e}" return False, f"Migration failed: {e}"
def migrate_refresh_log_source_row_count(conn: sqlite3.Connection) -> tuple[bool, str]:
    """Ensure pathway_refresh_log has a source_row_count column.

    The column stores the Snowflake row count for display in the UI footer.

    Args:
        conn: SQLite database connection.

    Returns:
        (success, message). success is True when the column already existed
        or was added, False when the ALTER TABLE attempt raised.
    """
    table_info = conn.execute("PRAGMA table_info(pathway_refresh_log)").fetchall()
    # Index 1 of each table_info row is compared against the column name.
    if any(col[1] == "source_row_count" for col in table_info):
        return True, "source_row_count column already exists"

    logger.info("Adding source_row_count column to pathway_refresh_log...")
    try:
        conn.execute("""
            ALTER TABLE pathway_refresh_log
            ADD COLUMN source_row_count INTEGER
        """)
        conn.commit()
    except Exception as e:
        logger.error(f"Failed to add source_row_count column: {e}")
        return False, f"Migration failed: {e}"
    return True, "Added source_row_count column"
# ============================================================================= # =============================================================================
# Combined Helper Functions # Combined Helper Functions
# ============================================================================= # =============================================================================
def create_all_tables(conn: sqlite3.Connection) -> None: def create_all_tables(conn: sqlite3.Connection) -> None:
""" """
Create all tables (reference + fact) in the database. Create all tables (reference + pathway) in the database.
Args: Args:
conn: SQLite database connection. conn: SQLite database connection.
@@ -1078,8 +673,6 @@ def drop_all_tables(conn: sqlite3.Connection) -> None:
""" """
logger.warning("Dropping all tables...") logger.warning("Dropping all tables...")
drop_pathway_tables(conn) drop_pathway_tables(conn)
drop_file_tracking_tables(conn)
drop_fact_tables(conn)
drop_reference_tables(conn) drop_reference_tables(conn)
logger.info("All tables dropped") logger.info("All tables dropped")
@@ -1096,8 +689,6 @@ def get_all_table_counts(conn: sqlite3.Connection) -> dict[str, int]:
""" """
counts = {} counts = {}
counts.update(get_reference_table_counts(conn)) counts.update(get_reference_table_counts(conn))
counts.update(get_fact_table_counts(conn))
counts.update(get_file_tracking_counts(conn))
counts.update(get_pathway_table_counts(conn)) counts.update(get_pathway_table_counts(conn))
return counts return counts
@@ -1114,7 +705,5 @@ def verify_all_tables_exist(conn: sqlite3.Connection) -> list[str]:
""" """
missing = [] missing = []
missing.extend(verify_reference_tables_exist(conn)) missing.extend(verify_reference_tables_exist(conn))
missing.extend(verify_fact_tables_exist(conn))
missing.extend(verify_file_tracking_tables_exist(conn))
missing.extend(verify_pathway_tables_exist(conn)) missing.extend(verify_pathway_tables_exist(conn))
return missing return missing
+33 -520
View File
@@ -5,7 +5,7 @@ Single-page dashboard with reactive filtering and real-time chart updates.
Design reference: DESIGN_SYSTEM.md Design reference: DESIGN_SYSTEM.md
""" """
from datetime import datetime, timedelta from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -69,14 +69,6 @@ class AppState(rx.State):
# Data freshness tracking # Data freshness tracking
last_updated: str = "" # ISO format timestamp of last data load last_updated: str = "" # ISO format timestamp of last data load
# Raw data storage - list of dicts (Reflex-friendly)
# Each dict represents a patient record with keys like:
# UPID, Drug Name, Intervention Date, Price Actual, Directory, etc.
raw_data: list[dict[str, Any]] = []
# Latest date in dataset (detected on load, used for "to" date defaults)
latest_date_in_data: str = ""
# ========================================================================= # =========================================================================
# UI State Variables # UI State Variables
# ========================================================================= # =========================================================================
@@ -111,22 +103,7 @@ class AppState(rx.State):
{"value": "12mo", "label": "Last 12 months"}, {"value": "12mo", "label": "Last 12 months"},
] ]
# Legacy date filter state (kept for backwards compatibility, will be removed) # Available options for dropdowns (populated from data)
# Filter toggle state
initiated_filter_enabled: bool = False
last_seen_filter_enabled: bool = True
# Date filter values (ISO format strings YYYY-MM-DD)
# Initiated filter: Defaults empty (filter is OFF by default)
initiated_from_date: str = ""
initiated_to_date: str = ""
# Last Seen filter: Defaults to last 6 months (filter is ON by default)
# These will be updated on data load to use actual latest date
last_seen_from_date: str = (datetime.now() - timedelta(days=180)).strftime("%Y-%m-%d")
last_seen_to_date: str = datetime.now().strftime("%Y-%m-%d")
# Available options for dropdowns (populated from data in Phase 3)
available_drugs: list[str] = ["Drug A", "Drug B", "Drug C", "Drug D", "Drug E"] available_drugs: list[str] = ["Drug A", "Drug B", "Drug C", "Drug D", "Drug E"]
available_indications: list[str] = ["Indication 1", "Indication 2", "Indication 3"] available_indications: list[str] = ["Indication 1", "Indication 2", "Indication 3"]
available_directorates: list[str] = ["Medical", "Surgical", "Oncology", "Rheumatology"] available_directorates: list[str] = ["Medical", "Surgical", "Oncology", "Rheumatology"]
@@ -146,45 +123,7 @@ class AppState(rx.State):
indication_dropdown_open: bool = False indication_dropdown_open: bool = False
directorate_dropdown_open: bool = False directorate_dropdown_open: bool = False
# Event handlers for filter toggles # Event handlers for date filter dropdowns
def toggle_initiated_filter(self):
    """Flip the initiated date filter flag; re-apply filters if data is loaded."""
    self.initiated_filter_enabled = not self.initiated_filter_enabled
    if not self.data_loaded:
        return
    self.apply_filters()
def toggle_last_seen_filter(self):
    """Flip the last seen date filter flag; re-apply filters if data is loaded."""
    self.last_seen_filter_enabled = not self.last_seen_filter_enabled
    if not self.data_loaded:
        return
    self.apply_filters()
# Event handlers for date changes
def set_initiated_from(self, value: str):
    """Store the initiated "from" date; re-apply filters if data is loaded."""
    self.initiated_from_date = value
    if not self.data_loaded:
        return
    self.apply_filters()
def set_initiated_to(self, value: str):
    """Store the initiated "to" date; re-apply filters if data is loaded."""
    self.initiated_to_date = value
    if not self.data_loaded:
        return
    self.apply_filters()
def set_last_seen_from(self, value: str):
    """Store the last seen "from" date; re-apply filters if data is loaded."""
    self.last_seen_from_date = value
    if not self.data_loaded:
        return
    self.apply_filters()
def set_last_seen_to(self, value: str):
    """Store the last seen "to" date; re-apply filters if data is loaded."""
    self.last_seen_to_date = value
    if not self.data_loaded:
        return
    self.apply_filters()
# Event handlers for date filter dropdowns (new pathway_nodes approach)
def set_initiated_filter(self, value: str): def set_initiated_filter(self, value: str):
"""Set initiated filter dropdown value.""" """Set initiated filter dropdown value."""
self.selected_initiated = value self.selected_initiated = value
@@ -461,141 +400,6 @@ class AppState(rx.State):
# Filter Logic Methods # Filter Logic Methods
# ========================================================================= # =========================================================================
def apply_filters(self):
    """
    Apply current filter state to data and update KPI values.

    Queries fact_interventions in data/pathways.db with the current filters:
    - Drug filter: rows whose drug_name_std is in selected_drugs (empty = all)
    - Directorate filter: rows whose directory is in selected_directorates
      (empty = all)
    - Initiated date filter (when enabled): keeps patients whose FIRST
      intervention date, among the rows passing the filters above, is in range
    - Last Seen date filter (when enabled): same, for the LAST intervention date

    The indication filter is not applied here.

    Updates unique_patients, total_drugs and total_cost, clears
    error_message, then calls prepare_chart_data(). The filtered record count
    is selected by the query but not stored. Failures are reported through
    error_message rather than raised.
    """
    import sqlite3
    from contextlib import closing

    db_path = Path("data/pathways.db")
    if not db_path.exists():
        self.error_message = "Unable to connect to database. Please ensure data has been loaded."
        return

    try:
        # Row-level filters, used both inside the patient CTE and on the
        # final join.
        where_clauses = []
        params = []

        if self.selected_drugs:
            placeholders = ",".join("?" * len(self.selected_drugs))
            where_clauses.append(f"drug_name_std IN ({placeholders})")
            params.extend(self.selected_drugs)

        if self.selected_directorates:
            placeholders = ",".join("?" * len(self.selected_directorates))
            where_clauses.append(f"directory IN ({placeholders})")
            params.extend(self.selected_directorates)

        if where_clauses:
            row_filter = " AND ".join(where_clauses)
            base_where = "WHERE " + row_filter
            # Same predicates, appended to the final join's ON condition.
            join_filter = "AND " + row_filter
        else:
            base_where = ""
            join_filter = ""

        # Patient-level date filters on the per-patient MIN/MAX dates.
        having_clauses = []
        having_params = []

        if self.initiated_filter_enabled and self.initiated_from_date:
            having_clauses.append("first_seen_date >= ?")
            having_params.append(self.initiated_from_date)
        if self.initiated_filter_enabled and self.initiated_to_date:
            having_clauses.append("first_seen_date <= ?")
            having_params.append(self.initiated_to_date)

        if self.last_seen_filter_enabled and self.last_seen_from_date:
            having_clauses.append("last_seen_date >= ?")
            having_params.append(self.last_seen_from_date)
        if self.last_seen_filter_enabled and self.last_seen_to_date:
            having_clauses.append("last_seen_date <= ?")
            having_params.append(self.last_seen_to_date)

        having_clause = ""
        if having_clauses:
            having_clause = "HAVING " + " AND ".join(having_clauses)

        # UPIDs of the patients that pass the row and date filters.
        patient_filter_query = f"""
            WITH patient_dates AS (
                SELECT
                    upid,
                    MIN(intervention_date) as first_seen_date,
                    MAX(intervention_date) as last_seen_date
                FROM fact_interventions
                {base_where}
                GROUP BY upid
                {having_clause}
            )
            SELECT upid FROM patient_dates
        """

        # KPI values over the rows of those patients.
        kpi_query = f"""
            WITH filtered_patients AS (
                {patient_filter_query}
            )
            SELECT
                COUNT(DISTINCT f.upid) as unique_patients,
                COUNT(DISTINCT f.drug_name_std) as unique_drugs,
                COALESCE(SUM(f.price_actual), 0) as total_cost,
                COUNT(*) as record_count
            FROM fact_interventions f
            INNER JOIN filtered_patients fp ON f.upid = fp.upid
            {join_filter}
        """

        # Placeholder order follows the SQL text: CTE WHERE, CTE HAVING,
        # then the row filters repeated on the final join.
        all_params = params + having_params
        if where_clauses:
            all_params.extend(params)

        # closing() guarantees the connection is released even when the
        # query raises; "with connection:" alone would not close it.
        with closing(sqlite3.connect(str(db_path))) as conn:
            result = conn.execute(kpi_query, all_params).fetchone()

        if result:
            self.unique_patients = result[0] or 0
            self.total_drugs = result[1] or 0
            self.total_cost = float(result[2]) if result[2] else 0.0
            # result[3] (record count) is available but currently unused.

        self.error_message = ""

        # Update chart data with new filtered results
        self.prepare_chart_data()

    except sqlite3.Error as e:
        self.error_message = f"Unable to filter data. Database error: {str(e)}"
    except Exception as e:
        self.error_message = f"An unexpected error occurred while filtering. Details: {str(e)}"
# ========================================================================= # =========================================================================
# Data Loading Methods # Data Loading Methods
# ========================================================================= # =========================================================================
@@ -604,11 +408,8 @@ class AppState(rx.State):
""" """
Load data from SQLite database on app initialization. Load data from SQLite database on app initialization.
This method: Sources available drugs/directorates from pathway_nodes and total_records
1. Connects to the SQLite database (data/pathways.db) from the latest pathway_refresh_log entry.
2. Loads available drugs, indications, directorates from actual data
3. Detects the latest date in the dataset for "to" date defaults
4. Updates total_records, last_updated, and data_loaded state
""" """
import sqlite3 import sqlite3
@@ -622,30 +423,38 @@ class AppState(rx.State):
conn = sqlite3.connect(str(db_path)) conn = sqlite3.connect(str(db_path))
cursor = conn.cursor() cursor = conn.cursor()
# Get total records # Get total source records from latest completed refresh log
cursor.execute("SELECT COUNT(*) FROM fact_interventions")
self.total_records = cursor.fetchone()[0]
if self.total_records == 0:
self.error_message = "The database is empty. No patient records found."
conn.close()
return
# Get available drugs (distinct, sorted)
cursor.execute(""" cursor.execute("""
SELECT DISTINCT drug_name_std SELECT source_row_count, completed_at
FROM fact_interventions FROM pathway_refresh_log
WHERE drug_name_std IS NOT NULL AND drug_name_std != '' WHERE status = 'completed'
ORDER BY drug_name_std ORDER BY started_at DESC
LIMIT 1
""")
refresh_row = cursor.fetchone()
if refresh_row:
self.total_records = refresh_row[0] or 0
if refresh_row[1]:
self.last_updated = refresh_row[1]
else:
self.total_records = 0
# Get available drugs from pathway_nodes (level 3 = drug nodes)
cursor.execute("""
SELECT DISTINCT labels
FROM pathway_nodes
WHERE level = 3 AND labels IS NOT NULL AND labels != ''
ORDER BY labels
""") """)
self.available_drugs = [row[0] for row in cursor.fetchall()] self.available_drugs = [row[0] for row in cursor.fetchall()]
# Get available directories (distinct, sorted) # Get available directorates from directory chart pathway_nodes (level 2)
cursor.execute(""" cursor.execute("""
SELECT DISTINCT directory SELECT DISTINCT labels
FROM fact_interventions FROM pathway_nodes
WHERE directory IS NOT NULL AND directory != '' WHERE level = 2 AND chart_type = 'directory'
ORDER BY directory AND labels IS NOT NULL AND labels != ''
ORDER BY labels
""") """)
self.available_directorates = [row[0] for row in cursor.fetchall()] self.available_directorates = [row[0] for row in cursor.fetchall()]
@@ -658,50 +467,17 @@ class AppState(rx.State):
""") """)
self.available_indications = [row[0] for row in cursor.fetchall()] self.available_indications = [row[0] for row in cursor.fetchall()]
# If no indications in reference table, use placeholder
if not self.available_indications: if not self.available_indications:
self.available_indications = ["(No indications available)"] self.available_indications = ["(No indications available)"]
# Get date range from data
cursor.execute("""
SELECT MIN(intervention_date), MAX(intervention_date)
FROM fact_interventions
""")
date_range = cursor.fetchone()
min_date, max_date = date_range
# Update latest_date_in_data and set "to" date defaults
if max_date:
self.latest_date_in_data = max_date
self.last_seen_to_date = max_date
self.initiated_to_date = max_date
# Set "from" date for last_seen filter (6 months before max_date)
max_dt = datetime.strptime(max_date, "%Y-%m-%d")
six_months_ago = max_dt - timedelta(days=180)
self.last_seen_from_date = six_months_ago.strftime("%Y-%m-%d")
# Get unique patient count for KPIs
cursor.execute("SELECT COUNT(DISTINCT upid) FROM fact_interventions")
self.unique_patients = cursor.fetchone()[0]
# Get unique drug count
self.total_drugs = len(self.available_drugs)
# Get total cost
cursor.execute("SELECT SUM(price_actual) FROM fact_interventions")
total_cost_result = cursor.fetchone()[0]
self.total_cost = float(total_cost_result) if total_cost_result else 0.0
conn.close() conn.close()
# Set data_loaded and last_updated
self.data_loaded = True self.data_loaded = True
if not self.last_updated:
self.last_updated = datetime.now().isoformat() self.last_updated = datetime.now().isoformat()
self.error_message = "" self.error_message = ""
# Load pre-computed pathway data for the default date filter # Load pre-computed pathway data for the default date filter
# This replaces apply_filters() which used dynamic calculation
self.load_pathway_data() self.load_pathway_data()
except sqlite3.Error as e: except sqlite3.Error as e:
@@ -986,269 +762,6 @@ class AppState(rx.State):
chart_data: list[dict[str, Any]] = [] chart_data: list[dict[str, Any]] = []
chart_title: str = "" chart_title: str = ""
def prepare_chart_data(self):
    """
    Prepare hierarchical data for Plotly icicle chart.

    Queries fact_interventions in data/pathways.db for the patients that pass
    the current drug, directorate and date filters and turns the result into
    a hierarchy: Root -> Trust -> Directory -> Drug.

    The chart data is stored in self.chart_data as a list of dicts with:
    - parents: Parent node identifier ("" for the root)
    - ids: Unique node identifier (hierarchical path joined with " - ")
    - labels: Display label
    - value: Patient count (trust and directory values are sums of their
      drug rows' patient counts)
    - cost: Total cost
    - colour: value divided by the parent's value (1.0 for the root)

    Updates: chart_data, chart_title, chart_loading, error_message.
    On any failure chart_data is emptied and error_message describes it.
    """
    import sqlite3
    from contextlib import closing

    db_path = Path("data/pathways.db")
    if not db_path.exists():
        self.error_message = "Unable to generate chart. Database not found."
        self.chart_data = []
        return

    self.chart_loading = True

    try:
        # Row-level filters, used in the patient CTE and again when
        # selecting those patients' records.
        where_clauses = []
        params = []

        if self.selected_drugs:
            placeholders = ",".join("?" * len(self.selected_drugs))
            where_clauses.append(f"drug_name_std IN ({placeholders})")
            params.extend(self.selected_drugs)

        if self.selected_directorates:
            placeholders = ",".join("?" * len(self.selected_directorates))
            where_clauses.append(f"directory IN ({placeholders})")
            params.extend(self.selected_directorates)

        if where_clauses:
            row_filter = " AND ".join(where_clauses)
            base_where = "WHERE " + row_filter
            # Same predicates, appended to the join's ON condition.
            join_filter = "AND " + row_filter
        else:
            base_where = ""
            join_filter = ""

        # Patient-level date filters on the per-patient MIN/MAX dates.
        having_clauses = []
        having_params = []

        if self.initiated_filter_enabled and self.initiated_from_date:
            having_clauses.append("first_seen >= ?")
            having_params.append(self.initiated_from_date)
        if self.initiated_filter_enabled and self.initiated_to_date:
            having_clauses.append("first_seen <= ?")
            having_params.append(self.initiated_to_date)

        if self.last_seen_filter_enabled and self.last_seen_from_date:
            having_clauses.append("last_seen >= ?")
            having_params.append(self.last_seen_from_date)
        if self.last_seen_filter_enabled and self.last_seen_to_date:
            having_clauses.append("last_seen <= ?")
            having_params.append(self.last_seen_to_date)

        having_clause = ""
        if having_clauses:
            having_clause = "HAVING " + " AND ".join(having_clauses)

        # Aggregate by Trust -> Directory -> Drug; the trust name falls back
        # to provider_code when org_name is NULL.
        chart_query = f"""
            WITH filtered_patients AS (
                SELECT upid
                FROM (
                    SELECT
                        upid,
                        MIN(intervention_date) as first_seen,
                        MAX(intervention_date) as last_seen
                    FROM fact_interventions
                    {base_where}
                    GROUP BY upid
                    {having_clause}
                )
            ),
            patient_records AS (
                SELECT
                    f.upid,
                    COALESCE(f.org_name, f.provider_code) as trust_name,
                    f.directory,
                    f.drug_name_std,
                    f.price_actual
                FROM fact_interventions f
                INNER JOIN filtered_patients fp ON f.upid = fp.upid
                {join_filter}
            )
            SELECT
                trust_name,
                directory,
                drug_name_std,
                COUNT(DISTINCT upid) as patient_count,
                COALESCE(SUM(price_actual), 0) as total_cost
            FROM patient_records
            GROUP BY trust_name, directory, drug_name_std
            ORDER BY trust_name, directory, drug_name_std
        """

        # Placeholder order follows the SQL text: CTE WHERE, CTE HAVING,
        # then the row filters repeated on the join.
        all_params = params + having_params
        if where_clauses:
            all_params.extend(params)

        # closing() guarantees the connection is released even when the
        # query raises; "with connection:" alone would not close it.
        with closing(sqlite3.connect(str(db_path))) as conn:
            rows = conn.execute(chart_query, all_params).fetchall()

        # Root node; value and cost are filled in once the totals are known.
        root_id = "N&WICS"
        chart_data = [{
            "parents": "",
            "ids": root_id,
            "labels": "Norfolk & Waveney ICS",
            "value": 0,
            "cost": 0.0,
            "colour": 1.0,
        }]

        trust_totals = {}
        directory_totals = {}
        drug_data = []

        for trust_name, directory, drug_name, patient_count, cost in rows:
            # Rows missing any level of the hierarchy cannot be placed.
            if not trust_name or not directory or not drug_name:
                continue

            # Trust level
            trust_id = f"{root_id} - {trust_name}"
            if trust_id not in trust_totals:
                trust_totals[trust_id] = {"value": 0, "cost": 0.0, "label": trust_name}
            trust_totals[trust_id]["value"] += patient_count
            trust_totals[trust_id]["cost"] += cost

            # Directory level
            dir_id = f"{trust_id} - {directory}"
            if dir_id not in directory_totals:
                directory_totals[dir_id] = {
                    "value": 0,
                    "cost": 0.0,
                    "label": directory,
                    "parent": trust_id,
                }
            directory_totals[dir_id]["value"] += patient_count
            directory_totals[dir_id]["cost"] += cost

            # Drug level (leaf)
            drug_data.append({
                "ids": f"{dir_id} - {drug_name}",
                "labels": drug_name,
                "parent": dir_id,
                "value": patient_count,
                "cost": float(cost),
            })

        root_total = sum(t["value"] for t in trust_totals.values())
        root_cost = sum(t["cost"] for t in trust_totals.values())
        chart_data[0]["value"] = root_total
        chart_data[0]["cost"] = root_cost

        # Trust nodes: colour is the share of the root total.
        for trust_id, data in trust_totals.items():
            colour = data["value"] / root_total if root_total > 0 else 0
            chart_data.append({
                "parents": root_id,
                "ids": trust_id,
                "labels": data["label"],
                "value": data["value"],
                "cost": data["cost"],
                "colour": colour,
            })

        # Directory nodes: colour is the share of the parent trust.
        for dir_id, data in directory_totals.items():
            parent_total = trust_totals[data["parent"]]["value"]
            colour = data["value"] / parent_total if parent_total > 0 else 0
            chart_data.append({
                "parents": data["parent"],
                "ids": dir_id,
                "labels": data["label"],
                "value": data["value"],
                "cost": data["cost"],
                "colour": colour,
            })

        # Drug nodes: colour is the share of the parent directory.
        for drug in drug_data:
            parent_dir = drug["parent"]
            parent_total = directory_totals[parent_dir]["value"]
            colour = drug["value"] / parent_total if parent_total > 0 else 0
            chart_data.append({
                "parents": parent_dir,
                "ids": drug["ids"],
                "labels": drug["labels"],
                "value": drug["value"],
                "cost": drug["cost"],
                "colour": colour,
            })

        self.chart_data = chart_data
        self.chart_title = self._generate_chart_title()
        self.chart_loading = False
        self.error_message = ""

    except sqlite3.Error as e:
        self.error_message = f"Unable to generate chart. Database error: {str(e)}"
        self.chart_data = []
        self.chart_loading = False
    except Exception as e:
        self.error_message = f"Unable to generate chart. Details: {str(e)}"
        self.chart_data = []
        self.chart_loading = False
def _generate_chart_title(self) -> str:
"""Generate chart title based on current filter state."""
parts = []
# Date range info
if self.last_seen_filter_enabled:
parts.append(f"Last seen: {self.last_seen_from_date} to {self.last_seen_to_date}")
elif self.initiated_filter_enabled:
parts.append(f"Initiated: {self.initiated_from_date} to {self.initiated_to_date}")
# Drug selection info
if self.selected_drugs:
if len(self.selected_drugs) <= 3:
parts.append(", ".join(self.selected_drugs))
else:
parts.append(f"{len(self.selected_drugs)} drugs selected")
# Directorate selection info
if self.selected_directorates:
if len(self.selected_directorates) <= 2:
parts.append(", ".join(self.selected_directorates))
else:
parts.append(f"{len(self.selected_directorates)} directorates")
if parts:
return " | ".join(parts)
return "All Patients"
# ========================================================================= # =========================================================================
# Plotly Chart Generation # Plotly Chart Generation
# ========================================================================= # =========================================================================
-446
View File
@@ -1,446 +0,0 @@
"""
Large dataset performance tests for the Patient Pathway Analysis tool.
This module tests the system's ability to handle realistic workloads:
1. Full dataset analysis (all drugs, trusts, directories)
2. Memory usage under load
3. Scalability characteristics
Run with: python -m pytest tests/test_large_dataset_performance.py -v
"""
import gc
import time
import tracemalloc
from datetime import date
from pathlib import Path
import pytest
# Mark all tests in this module as large dataset tests
pytestmark = pytest.mark.largedata
class TestLargeDatasetPerformance:
"""Performance tests with full dataset."""
@pytest.fixture(autouse=True)
def setup_paths(self):
"""Set up paths and verify data exists."""
from core import default_paths
from data_processing import get_loader
# Check if database exists
db_path = default_paths.data_dir / "pathways.db"
if not db_path.exists():
pytest.skip("SQLite database not found")
self.paths = default_paths
self.loader = get_loader('sqlite')
# Load data once
result = self.loader.load()
if result is None or result.df is None or len(result.df) == 0:
pytest.skip("No data available in database")
self.df = result.df
self.row_count = result.row_count
def test_data_load_time_acceptable(self):
"""Data loading should complete in under 5 seconds."""
from data_processing import get_loader
gc.collect()
start = time.perf_counter()
loader = get_loader('sqlite')
result = loader.load()
elapsed = time.perf_counter() - start
assert result is not None, "Data loading failed"
assert result.row_count > 0, "No data loaded"
# Allow 5 seconds for data loading
assert elapsed < 5.0, f"Data loading took {elapsed:.2f}s (target: <5s)"
def test_analysis_pipeline_completes(self):
"""Full analysis pipeline should complete without error."""
from analysis.pathway_analyzer import generate_icicle_chart
import pandas as pd
# Get available filters from actual data
trusts = self.df['Provider Code'].unique().tolist()[:20]
drugs = self.df['Drug Name'].dropna().unique().tolist()[:10]
directories = self.df['Directory'].dropna().unique().tolist()
# Load org codes for trust name mapping
org_codes = pd.read_csv(self.paths.org_codes_csv, index_col=1)
trust_names = []
for t in trusts:
if t in org_codes.index:
trust_names.append(org_codes.loc[t, 'Name'])
if not trust_names:
trust_names = org_codes['Name'].tolist()[:20]
# Run analysis with reasonable filter
ice_df, title = generate_icicle_chart(
df=self.df,
start_date="2020-01-01",
end_date="2025-01-01",
last_seen_date="2020-01-01",
trust_filter=trust_names,
drug_filter=drugs,
directory_filter=directories,
minimum_num_patients=1,
title="Large Dataset Test",
paths=self.paths,
)
# Should produce some results
assert ice_df is not None, "Analysis produced no results"
assert len(ice_df) > 0, "Analysis produced empty results"
def test_analysis_pipeline_time_acceptable(self):
"""Analysis pipeline should complete in under 60 seconds."""
from analysis.pathway_analyzer import generate_icicle_chart
import pandas as pd
# Get available filters from actual data
trusts = self.df['Provider Code'].unique().tolist()[:20]
drugs = self.df['Drug Name'].dropna().unique().tolist()[:10]
directories = self.df['Directory'].dropna().unique().tolist()
# Load org codes for trust name mapping
org_codes = pd.read_csv(self.paths.org_codes_csv, index_col=1)
trust_names = []
for t in trusts:
if t in org_codes.index:
trust_names.append(org_codes.loc[t, 'Name'])
if not trust_names:
trust_names = org_codes['Name'].tolist()[:20]
gc.collect()
start = time.perf_counter()
ice_df, title = generate_icicle_chart(
df=self.df,
start_date="2020-01-01",
end_date="2025-01-01",
last_seen_date="2020-01-01",
trust_filter=trust_names,
drug_filter=drugs,
directory_filter=directories,
minimum_num_patients=1,
title="Performance Test",
paths=self.paths,
)
elapsed = time.perf_counter() - start
# Allow 60 seconds for full analysis (observed ~19s with 440K rows)
assert elapsed < 60.0, f"Analysis took {elapsed:.2f}s (target: <60s)"
print(f"\n Analysis completed in {elapsed:.2f}s with {len(ice_df) if ice_df is not None else 0} result rows")
def test_memory_usage_acceptable(self):
"""Memory usage should not exceed 500MB during analysis."""
from analysis.pathway_analyzer import generate_icicle_chart
import pandas as pd
# Get available filters from actual data
trusts = self.df['Provider Code'].unique().tolist()[:15]
drugs = self.df['Drug Name'].dropna().unique().tolist()[:5]
directories = self.df['Directory'].dropna().unique().tolist()
# Load org codes for trust name mapping
org_codes = pd.read_csv(self.paths.org_codes_csv, index_col=1)
trust_names = []
for t in trusts:
if t in org_codes.index:
trust_names.append(org_codes.loc[t, 'Name'])
if not trust_names:
trust_names = org_codes['Name'].tolist()[:15]
gc.collect()
tracemalloc.start()
ice_df, title = generate_icicle_chart(
df=self.df,
start_date="2020-01-01",
end_date="2025-01-01",
last_seen_date="2020-01-01",
trust_filter=trust_names,
drug_filter=drugs,
directory_filter=directories,
minimum_num_patients=1,
title="Memory Test",
paths=self.paths,
)
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
peak_mb = peak / 1024 / 1024
# Allow 500MB peak memory
assert peak_mb < 500, f"Peak memory {peak_mb:.1f}MB exceeds 500MB limit"
print(f"\n Peak memory usage: {peak_mb:.1f}MB")
def test_figure_creation_scales(self):
    """Figure creation time must not grow much faster than the row count.

    Builds synthetic icicle frames of 100 to 2000 rows, times
    ``create_icicle_figure`` once per size, and compares the largest run
    against the smallest one.
    """
    from visualization.plotly_generator import create_icicle_figure
    import pandas as pd
    import numpy as np

    def timed_build(n_rows):
        """Return the seconds spent creating a figure for *n_rows* rows."""
        frame = pd.DataFrame({
            'parents': ['N&WICS'] * n_rows,
            'ids': [f'N&WICS - Test{i}' for i in range(n_rows)],
            'labels': [f'Test{i}' for i in range(n_rows)],
            'value': np.random.randint(1, 100, n_rows),
            'colour': np.random.random(n_rows),
            'cost': np.random.randint(1000, 100000, n_rows),
            'costpp': np.random.randint(100, 10000, n_rows),
            'cost_pp_pa': [str(np.random.randint(100, 10000)) for _ in range(n_rows)],
            'First seen': pd.to_datetime(['2024-01-01'] * n_rows),
            'Last seen': pd.to_datetime(['2024-12-31'] * n_rows),
            'First seen (Parent)': ['2024-01-01'] * n_rows,
            'Last seen (Parent)': ['2024-12-31'] * n_rows,
            'average_spacing': ['Test spacing'] * n_rows,
            'avg_days': pd.to_timedelta([100] * n_rows, unit='D'),
        })
        # Collect garbage first so a pending GC pass is not billed to the
        # timed section.
        gc.collect()
        began = time.perf_counter()
        create_icicle_figure(frame, f"Scale Test {n_rows}")
        return time.perf_counter() - began

    sizes = [100, 500, 1000, 2000]
    times = [timed_build(n_rows) for n_rows in sizes]

    # Perfectly linear scaling would make the time ratio equal the size
    # ratio (20x here). Timings are noisy, so up to three times that (60x)
    # is tolerated before the test fails.
    size_ratio = sizes[-1] / sizes[0]
    time_ratio = times[-1] / times[0]
    max_allowed_ratio = size_ratio * 3

    assert time_ratio < max_allowed_ratio, (
        f"Figure creation doesn't scale well: "
        f"{sizes[-1]} rows took {times[-1]:.3f}s vs {sizes[0]} rows at {times[0]:.3f}s "
        f"(ratio {time_ratio:.1f}x, expected <{max_allowed_ratio:.1f}x)"
    )
    print(f"\n Figure scaling: {sizes[0]} rows: {times[0]*1000:.1f}ms, "
          f"{sizes[-1]} rows: {times[-1]*1000:.1f}ms (ratio: {time_ratio:.1f}x)")
class TestDataVolumeStress:
    """Stress tests to verify system handles various data volumes."""

    @pytest.fixture(autouse=True)
    def setup_paths(self):
        """Load the SQLite dataset for a test, skipping when it is unavailable.

        Sets ``self.paths``, ``self.loader`` and ``self.df``. The fixture is
        function-scoped, so the data is loaded again before every test.
        """
        from core import default_paths
        from data_processing import get_loader

        db_path = default_paths.data_dir / "pathways.db"
        if not db_path.exists():
            pytest.skip("SQLite database not found")

        self.paths = default_paths
        self.loader = get_loader('sqlite')

        result = self.loader.load()
        if result is None or result.df is None or len(result.df) == 0:
            pytest.skip("No data available in database")
        self.df = result.df

    def _trust_names(self, limit=None):
        """Return the 'Name' column of the org-codes CSV, cut to *limit* entries."""
        import pandas as pd

        org_codes = pd.read_csv(self.paths.org_codes_csv, index_col=1)
        names = org_codes['Name'].tolist()
        return names if limit is None else names[:limit]

    def _drug_names(self, limit=None):
        """Return the distinct non-null drug names in the data, cut to *limit*."""
        names = self.df['Drug Name'].dropna().unique().tolist()
        return names if limit is None else names[:limit]

    def _directories(self):
        """Return every distinct non-null directory present in the data."""
        return self.df['Directory'].dropna().unique().tolist()

    def _run_icicle(self, trust_names, minimum_num_patients, title,
                    start_date="2020-01-01", end_date="2025-01-01"):
        """Run the full icicle pipeline on the first five drugs of the data.

        Returns the ``(ice_df, title)`` pair produced by
        ``generate_icicle_chart``. Fresh drug and directory lists are built
        for every call.
        """
        from analysis.pathway_analyzer import generate_icicle_chart

        return generate_icicle_chart(
            df=self.df,
            start_date=start_date,
            end_date=end_date,
            # Every scenario in this class uses the start date as the
            # last-seen cut-off.
            last_seen_date=start_date,
            trust_filter=trust_names,
            drug_filter=self._drug_names(5),
            directory_filter=self._directories(),
            minimum_num_patients=minimum_num_patients,
            title=title,
            paths=self.paths,
        )

    def test_handles_all_drugs(self):
        """Analysis can handle filtering by all drugs."""
        from analysis.pathway_analyzer import prepare_data

        result = prepare_data(
            df=self.df,
            trust_filter=self._trust_names(5),
            drug_filter=self._drug_names(),
            directory_filter=self._directories(),
            paths=self.paths,
        )

        # Should complete without error (returns tuple)
        assert result is not None
        assert len(result) == 3  # (df, org_codes, directory_df)

    def test_handles_all_trusts(self):
        """Analysis can handle filtering by all trusts."""
        from analysis.pathway_analyzer import prepare_data

        result = prepare_data(
            df=self.df,
            trust_filter=self._trust_names(),
            drug_filter=['ADALIMUMAB', 'ETANERCEPT'],
            directory_filter=self._directories(),
            paths=self.paths,
        )

        # Should complete without error (returns tuple)
        assert result is not None
        assert len(result) == 3  # (df, org_codes, directory_df)

    def test_handles_wide_date_range(self):
        """Analysis can handle a wide date range via generate_icicle_chart."""
        # Reaching the end without an exception is the success criterion.
        # The frame may be None, so nothing is asserted about the result
        # (the previous "is not None or is None" check could never fail).
        self._run_icicle(
            trust_names=self._trust_names(10),
            minimum_num_patients=1,
            title="Wide Date Range Test",
            start_date="2010-01-01",
            end_date="2030-01-01",
        )

    def test_handles_minimum_patient_threshold(self):
        """Analysis correctly applies minimum patient threshold."""
        trust_names = self._trust_names(10)

        ice_df_50, _ = self._run_icicle(
            trust_names=trust_names,
            minimum_num_patients=50,
            title="Threshold Test 50",
        )
        ice_df_1, _ = self._run_icicle(
            trust_names=trust_names,
            minimum_num_patients=1,
            title="Threshold Test 1",
        )

        # Higher threshold should produce fewer or equal results; a missing
        # frame counts as zero rows.
        len_50 = len(ice_df_50) if ice_df_50 is not None else 0
        len_1 = len(ice_df_1) if ice_df_1 is not None else 0
        assert len_50 <= len_1, (
            f"Higher minimum threshold should produce fewer results: "
            f"min=50 gave {len_50} rows, min=1 gave {len_1} rows"
        )
class TestConcurrentOperations:
    """Tests for handling multiple operations."""

    @pytest.fixture(autouse=True)
    def setup_paths(self):
        """Expose ``self.paths``, skipping the test when pathways.db is missing."""
        from core import default_paths

        db_path = default_paths.data_dir / "pathways.db"
        if not db_path.exists():
            pytest.skip("SQLite database not found")

        self.paths = default_paths

    def test_multiple_data_loads(self):
        """Multiple data loads should not cause issues.

        Three fresh loaders must all report the same row count. A load that
        returns ``None`` is recorded as ``None`` so that a mix of failed and
        successful loads shows up as an inconsistency instead of being
        silently ignored.
        """
        from data_processing import get_loader

        row_counts = []
        for _ in range(3):
            result = get_loader('sqlite').load()
            row_counts.append(None if result is None else result.row_count)

        # With no data at all there is nothing to compare; skip rather than
        # fail with a misleading "inconsistent" message on an empty list.
        if all(count is None for count in row_counts):
            pytest.skip("No data available")

        # All loads should return same row count
        assert len(set(row_counts)) == 1, f"Inconsistent row counts: {row_counts}"

    def test_sequential_analyses(self):
        """Multiple sequential analyses should complete."""
        from analysis.pathway_analyzer import generate_icicle_chart
        from data_processing import get_loader
        import pandas as pd

        result = get_loader('sqlite').load()
        if result is None or result.df is None:
            pytest.skip("No data available")
        df = result.df

        org_codes = pd.read_csv(self.paths.org_codes_csv, index_col=1)
        trust_names = org_codes['Name'].tolist()[:5]

        # Three back-to-back runs on the same frame. Finishing without an
        # exception is the success criterion; the returned frame may be None,
        # so nothing is asserted about it (the previous
        # "is not None or is None" check could never fail).
        for run_number in range(1, 4):
            generate_icicle_chart(
                df=df,
                start_date="2020-01-01",
                end_date="2025-01-01",
                last_seen_date="2020-01-01",
                trust_filter=trust_names,
                drug_filter=['ADALIMUMAB'],
                directory_filter=df['Directory'].dropna().unique().tolist(),
                minimum_num_patients=1,
                title=f"Sequential Test {run_number}",
                paths=self.paths,
            )