diff --git a/cli/refresh_pathways.py b/cli/refresh_pathways.py index 1f93263..f8c9731 100644 --- a/cli/refresh_pathways.py +++ b/cli/refresh_pathways.py @@ -176,6 +176,7 @@ def log_refresh_complete( record_count: int, date_filter_counts: dict[str, int], duration_seconds: float, + source_row_count: Optional[int] = None, ) -> None: """Log the successful completion of a refresh operation.""" conn.execute(""" @@ -184,13 +185,15 @@ def log_refresh_complete( status = 'completed', record_count = ?, date_filter_counts = ?, - processing_duration_seconds = ? + processing_duration_seconds = ?, + source_row_count = ? WHERE refresh_id = ? """, ( datetime.now().isoformat(), record_count, json.dumps(date_filter_counts), duration_seconds, + source_row_count, refresh_id, )) conn.commit() @@ -517,6 +520,7 @@ def refresh_pathways( record_count=stats["total_records"], date_filter_counts=stats["date_filter_counts"], duration_seconds=elapsed, + source_row_count=stats.get("snowflake_rows"), ) # Verify final counts diff --git a/data_processing/__init__.py b/data_processing/__init__.py index fe9f16d..c636ddd 100644 --- a/data_processing/__init__.py +++ b/data_processing/__init__.py @@ -24,15 +24,6 @@ from data_processing.schema import ( REF_DRUG_DIRECTORY_MAP_SCHEMA, REF_DRUG_INDICATION_CLUSTERS_SCHEMA, REFERENCE_TABLES_SCHEMA, - # Fact table schemas - FACT_INTERVENTIONS_SCHEMA, - FACT_TABLES_SCHEMA, - # Materialized view schemas - MV_PATIENT_TREATMENT_SUMMARY_SCHEMA, - MATERIALIZED_VIEWS_SCHEMA, - # File tracking schemas - PROCESSED_FILES_SCHEMA, - FILE_TRACKING_SCHEMA, # Combined schema ALL_TABLES_SCHEMA, # Reference table functions @@ -40,16 +31,6 @@ from data_processing.schema import ( drop_reference_tables, get_reference_table_counts, verify_reference_tables_exist, - # Fact table functions - create_fact_tables, - drop_fact_tables, - get_fact_table_counts, - verify_fact_tables_exist, - # File tracking functions - create_file_tracking_tables, - drop_file_tracking_tables, - get_file_tracking_counts, - verify_file_tracking_tables_exist, # Combined functions create_all_tables, drop_all_tables, @@ -81,27 +62,12 @@ from data_processing.reference_data import ( from data_processing.loader import ( DataLoader, FileDataLoader, - SQLiteDataLoader, LoadResult, get_loader, REQUIRED_COLUMNS, OPTIONAL_COLUMNS, ) -# Patient data migration functions -from data_processing.patient_data import ( - PatientDataLoadResult, - load_patient_data, - get_patient_data_stats, - list_processed_files, - calculate_file_hash, - # Materialized view functions - MVRefreshResult, - refresh_patient_treatment_summary, - get_patient_summary_stats, - verify_mv_consistency, -) - # Snowflake connector from data_processing.snowflake_connector import ( SnowflakeConnector, @@ -165,15 +131,6 @@ __all__ = [ "REF_DRUG_DIRECTORY_MAP_SCHEMA", "REF_DRUG_INDICATION_CLUSTERS_SCHEMA", "REFERENCE_TABLES_SCHEMA", - # Fact table schemas - "FACT_INTERVENTIONS_SCHEMA", - "FACT_TABLES_SCHEMA", - # Materialized view schemas - "MV_PATIENT_TREATMENT_SUMMARY_SCHEMA", - "MATERIALIZED_VIEWS_SCHEMA", - # File tracking schemas - "PROCESSED_FILES_SCHEMA", - "FILE_TRACKING_SCHEMA", # Combined schema "ALL_TABLES_SCHEMA", # Reference table functions @@ -181,16 +138,6 @@ __all__ = [ "drop_reference_tables", "get_reference_table_counts", "verify_reference_tables_exist", - # Fact table functions - "create_fact_tables", - "drop_fact_tables", - "get_fact_table_counts", - "verify_fact_tables_exist", - # File tracking functions - "create_file_tracking_tables", - "drop_file_tracking_tables", - "get_file_tracking_counts", - "verify_file_tracking_tables_exist", # Combined functions "create_all_tables", "drop_all_tables", @@ -216,22 +163,10 @@ __all__ = [ # Data loader abstractions "DataLoader", "FileDataLoader", - "SQLiteDataLoader", "LoadResult", "get_loader", "REQUIRED_COLUMNS", "OPTIONAL_COLUMNS", - # Patient data migration - "PatientDataLoadResult", - "load_patient_data", - "get_patient_data_stats", - "list_processed_files", - "calculate_file_hash", - # Materialized view functions - "MVRefreshResult", - "refresh_patient_treatment_summary", - "get_patient_summary_stats", - "verify_mv_consistency", # Snowflake connector "SnowflakeConnector", "SnowflakeConnectionError", diff --git a/data_processing/data_source.py b/data_processing/data_source.py index c4f1a1d..4d921f5 100644 --- a/data_processing/data_source.py +++ b/data_processing/data_source.py @@ -232,9 +232,9 @@ class DataSourceManager: ) def _check_sqlite_status(self) -> SourceStatus: - """Check if SQLite database is available with data.""" + """Check if SQLite database is available with pathway data.""" try: - from data_processing.database import default_db_manager, default_db_config + from data_processing.database import default_db_config db_path = self._sqlite_db_path or Path(default_db_config.db_path) @@ -252,22 +252,22 @@ class DataSourceManager: config = DatabaseConfig(db_path=db_path) manager = DatabaseManager(config) - if not manager.table_exists("fact_interventions"): + if not manager.table_exists("pathway_nodes"): return SourceStatus( source_type=DataSourceType.SQLITE, available=False, configured=True, - message="fact_interventions table not found", + message="pathway_nodes table not found", last_checked=datetime.now(), ) - count = manager.get_table_count("fact_interventions") + count = manager.get_table_count("pathway_nodes") if count == 0: return SourceStatus( source_type=DataSourceType.SQLITE, available=False, configured=True, - message="fact_interventions table is empty", + message="pathway_nodes table is empty", last_checked=datetime.now(), ) @@ -275,7 +275,7 @@ class DataSourceManager: source_type=DataSourceType.SQLITE, available=True, configured=True, - message=f"SQLite database ready ({count:,} rows)", + message=f"SQLite database ready ({count:,} pathway nodes)", last_checked=datetime.now(), ) except Exception as e: @@ -535,50 +535,14 @@ class DataSourceManager: drugs: Optional[list[str]], directories: Optional[list[str]], ) -> Optional[DataSourceResult]: - """Try to get data from SQLite.""" - import time + """Try to get data from SQLite. - try: - from data_processing.loader import SQLiteDataLoader - - # Determine database path - db_path = self._sqlite_db_path - if db_path is None: - from data_processing.database import default_db_config - db_path = Path(default_db_config.db_path) - - loader = SQLiteDataLoader( - db_path=db_path, - date_range=(start_date, end_date) if start_date and end_date else None, - trusts=trusts, - drugs=drugs, - directories=directories, - ) - - # Check if source is valid - is_valid, msg = loader.validate_source() - if not is_valid: - logger.debug(f"SQLite not available: {msg}") - return None - - start_time = time.time() - result = loader.load() - load_time = time.time() - start_time - - logger.info(f"SQLite loaded {result.row_count} rows in {load_time:.2f}s") - - return DataSourceResult( - df=result.df, - source_type=DataSourceType.SQLITE, - source_detail=str(db_path), - row_count=result.row_count, - cached=False, - from_fallback=False, - load_time_seconds=load_time, - ) - except Exception as e: - logger.warning(f"SQLite query failed: {e}") - return None + Note: Raw intervention data is no longer stored in SQLite. + The app now uses pre-computed pathway_nodes via load_pathway_data(). + This fallback is retained for interface compatibility but always returns None. + """ + logger.debug("SQLite raw data fallback skipped (fact_interventions removed)") + return None def _try_file( self, diff --git a/data_processing/diagnosis_lookup.py b/data_processing/diagnosis_lookup.py index 0d90d8e..02ea5aa 100644 --- a/data_processing/diagnosis_lookup.py +++ b/data_processing/diagnosis_lookup.py @@ -78,42 +78,6 @@ class DrugIndicationMatchRate: sample_unmatched: list[str] = field(default_factory=list) # Sample patient IDs -@dataclass -class DrugSnomedMapping: - """SNOMED code mapping for a drug from ref_drug_snomed_mapping.""" - snomed_code: str - snomed_description: str - search_term: str - primary_directorate: str - indication: str = "" - ta_id: str = "" - - -@dataclass -class DirectSnomedMatchResult: - """Result of direct SNOMED code lookup in GP records.""" - patient_pseudonym: str - matched: bool - snomed_code: Optional[str] = None - snomed_description: Optional[str] = None - search_term: Optional[str] = None - primary_directorate: Optional[str] = None - event_date: Optional[datetime] = None - source: str = "DIRECT_SNOMED" # DIRECT_SNOMED | NONE - - -@dataclass -class DirectorateAssignment: - """Result of directorate assignment for a patient-drug combination.""" - upid: str - drug_name: str - directorate: Optional[str] - search_term: Optional[str] = None - source: str = "FALLBACK" # DIAGNOSIS | FALLBACK - snomed_code: Optional[str] = None - event_date: Optional[datetime] = None - - def get_drug_clusters( drug_name: str, db_manager: Optional[DatabaseManager] = None @@ -180,266 +144,6 @@ def get_drug_cluster_ids( return list(set(c["cluster_id"] for c in clusters)) -def get_drug_snomed_codes( - drug_name: str, - db_manager: Optional[DatabaseManager] = None -) -> list[DrugSnomedMapping]: - """ - Get all SNOMED codes for a drug from local ref_drug_snomed_mapping table. - - This uses the enriched mapping CSV data loaded into SQLite, which provides - direct SNOMED-to-drug mappings with Search_Term and PrimaryDirectorate. - - Args: - drug_name: Drug name to look up (case-insensitive, matches cleaned_drug_name) - db_manager: Optional DatabaseManager (defaults to default_db_manager) - - Returns: - List of DrugSnomedMapping with snomed_code, snomed_description, - search_term, primary_directorate, indication, ta_id - """ - if db_manager is None: - db_manager = default_db_manager - - query = """ - SELECT DISTINCT - snomed_code, - snomed_description, - search_term, - primary_directorate, - indication, - ta_id - FROM ref_drug_snomed_mapping - WHERE UPPER(cleaned_drug_name) = UPPER(?) - OR UPPER(drug_name) = UPPER(?) - ORDER BY search_term, snomed_code - """ - - try: - with db_manager.get_connection() as conn: - cursor = conn.execute(query, (drug_name, drug_name)) - rows = cursor.fetchall() - - results = [] - for row in rows: - results.append(DrugSnomedMapping( - snomed_code=row["snomed_code"], - snomed_description=row["snomed_description"] or "", - search_term=row["search_term"] or "", - primary_directorate=row["primary_directorate"] or "", - indication=row["indication"] or "", - ta_id=row["ta_id"] or "", - )) - - logger.debug(f"Found {len(results)} SNOMED mappings for drug '{drug_name}'") - return results - - except Exception as e: - logger.error(f"Error getting SNOMED codes for drug '{drug_name}': {e}") - return [] - - -def patient_has_indication_direct( - patient_pseudonym: str, - drug_snomed_mappings: list[DrugSnomedMapping], - connector: Optional[SnowflakeConnector] = None, - before_date: Optional[date] = None, -) -> DirectSnomedMatchResult: - """ - Check if patient has any of the SNOMED codes in their GP records. - - This is the direct SNOMED lookup - it queries PrimaryCareClinicalCoding - for exact SNOMED code matches (not via cluster). Returns the most recent - match by EventDateTime if multiple matches exist. - - Args: - patient_pseudonym: Patient's pseudonymised NHS number - drug_snomed_mappings: List of DrugSnomedMapping from get_drug_snomed_codes() - connector: Optional SnowflakeConnector (defaults to singleton) - before_date: Optional date - only check diagnoses before this date - - Returns: - DirectSnomedMatchResult with match details (most recent by EventDateTime) - """ - result = DirectSnomedMatchResult( - patient_pseudonym=patient_pseudonym, - matched=False, - source="NONE", - ) - - if not drug_snomed_mappings: - return result - - if not SNOWFLAKE_AVAILABLE: - logger.warning("Snowflake connector not available") - return result - - if not is_snowflake_configured(): - logger.warning("Snowflake not configured - cannot check GP records") - return result - - if connector is None: - connector = get_connector() - - # Build lookup dict for mapping snomed_code -> (search_term, primary_directorate, snomed_description) - snomed_lookup = { - m.snomed_code: (m.search_term, m.primary_directorate, m.snomed_description) - for m in drug_snomed_mappings - } - - # Get unique SNOMED codes - snomed_codes = list(snomed_lookup.keys()) - - # Build placeholders for SNOMED codes - placeholders = ", ".join(["%s"] * len(snomed_codes)) - - # Query to find most recent matching SNOMED code in GP records - query = f''' - SELECT - "SNOMEDCode", - "EventDateTime" - FROM DATA_HUB.PHM."PrimaryCareClinicalCoding" - WHERE "PatientPseudonym" = %s - AND "SNOMEDCode" IN ({placeholders}) - ''' - - params: list = [patient_pseudonym] + snomed_codes - - if before_date: - query += ' AND "EventDateTime" < %s' - params.append(before_date.isoformat()) - - query += ' ORDER BY "EventDateTime" DESC LIMIT 1' - - try: - results = connector.execute_dict(query, tuple(params)) - - if results: - row = results[0] - matched_code = row.get("SNOMEDCode") - event_dt = row.get("EventDateTime") - - if matched_code and matched_code in snomed_lookup: - search_term, primary_dir, snomed_desc = snomed_lookup[matched_code] - - return DirectSnomedMatchResult( - patient_pseudonym=patient_pseudonym, - matched=True, - snomed_code=matched_code, - snomed_description=snomed_desc, - search_term=search_term, - primary_directorate=primary_dir, - event_date=event_dt, - source="DIRECT_SNOMED", - ) - - return result - - except Exception as e: - logger.error(f"Error checking direct SNOMED for patient '{patient_pseudonym}': {e}") - return result - - -def get_directorate_from_diagnosis( - upid: str, - drug_name: str, - connector: Optional[SnowflakeConnector] = None, - db_manager: Optional[DatabaseManager] = None, - before_date: Optional[date] = None, -) -> DirectorateAssignment: - """ - Get directorate assignment for a patient-drug combination using diagnosis-based lookup. - - This function attempts to assign a directorate based on the patient's GP records - (direct SNOMED code matching). If no match is found, it returns a FALLBACK result - indicating that the caller should use alternative assignment methods (e.g., - department_identification() from tools/data.py). - - Workflow: - 1. Get all SNOMED codes for the drug from ref_drug_snomed_mapping - 2. Query patient's GP records for matching SNOMED codes - 3. If match found → return diagnosis-based directorate and search_term - 4. If no match → return FALLBACK result (caller handles fallback logic) - - Args: - upid: Patient's unique patient ID (Provider Code[:3] + PersonKey) - drug_name: Drug name to look up - connector: Optional SnowflakeConnector (defaults to singleton) - db_manager: Optional DatabaseManager (defaults to default_db_manager) - before_date: Optional date - only check diagnoses before this date - - Returns: - DirectorateAssignment with directorate, search_term, and source - """ - result = DirectorateAssignment( - upid=upid, - drug_name=drug_name, - directorate=None, - source="FALLBACK", - ) - - # Step 1: Get SNOMED codes for the drug - drug_snomed_mappings = get_drug_snomed_codes(drug_name, db_manager) - - if not drug_snomed_mappings: - logger.debug(f"No SNOMED mappings found for drug '{drug_name}' - using fallback") - return result - - # Step 2: Check Snowflake availability - if not SNOWFLAKE_AVAILABLE: - logger.debug("Snowflake not available - using fallback") - return result - - if not is_snowflake_configured(): - logger.debug("Snowflake not configured - using fallback") - return result - - # Step 3: Get patient pseudonym from UPID - # UPID format is Provider Code (3 chars) + PersonKey - # We need to query Snowflake to get the PatientPseudonym for this PersonKey - # However, patient_has_indication_direct expects PatientPseudonym, not UPID - # For now, we'll use UPID as the identifier - the actual integration - # will need to happen at the DataFrame level where we have PersonKey - # - # NOTE: This function will be called from the pipeline where we have - # access to PatientPseudonym. The UPID is passed for logging/tracking. - - # Actually, looking at the pipeline, we need PatientPseudonym, not UPID. - # The caller should pass the PatientPseudonym or we need to look it up. - # For now, let's assume the caller will use this in a batch context - # where they can map UPID -> PatientPseudonym. - - # Let me reconsider: the function signature takes UPID but we need - # PatientPseudonym for Snowflake. In the pipeline context (fetch_and_transform_data), - # we'll have the PersonKey column which IS the PatientPseudonym. - # So UPID = ProviderCode[:3] + PersonKey, and PersonKey = PatientPseudonym. - # - # We can extract PatientPseudonym from UPID by removing the first 3 chars. - patient_pseudonym = upid[3:] if len(upid) > 3 else upid - - # Step 4: Check patient's GP records for matching SNOMED codes - match_result = patient_has_indication_direct( - patient_pseudonym=patient_pseudonym, - drug_snomed_mappings=drug_snomed_mappings, - connector=connector, - before_date=before_date, - ) - - if match_result.matched and match_result.primary_directorate: - return DirectorateAssignment( - upid=upid, - drug_name=drug_name, - directorate=match_result.primary_directorate, - search_term=match_result.search_term, - source="DIAGNOSIS", - snomed_code=match_result.snomed_code, - event_date=match_result.event_date, - ) - - # No match found - return fallback result - return result - - def get_cluster_snomed_codes( cluster_id: str, connector: Optional[SnowflakeConnector] = None, @@ -864,229 +568,6 @@ def get_available_clusters( return [] -def batch_lookup_indication_groups( - df: "pd.DataFrame", - connector: Optional[SnowflakeConnector] = None, - db_manager: Optional[DatabaseManager] = None, - batch_size: int = 500, -) -> "pd.DataFrame": - """ - Batch lookup GP diagnosis-based indication groups for a DataFrame of patients. - - This is the efficient batch version of get_directorate_from_diagnosis(). - Instead of querying Snowflake per patient, it batches the lookups for performance. - - Strategy: - 1. Get all unique (PersonKey, Drug Name) pairs from DataFrame - 2. For each unique drug, get all SNOMED codes from local SQLite - 3. Build batched Snowflake queries to check GP records - 4. Return indication_df mapping UPID → Indication_Group - - For unmatched patients, Indication_Group will be their Directory (with suffix). - - Args: - df: DataFrame with columns: UPID, Drug Name, Directory, PersonKey - connector: Optional SnowflakeConnector (defaults to singleton) - db_manager: Optional DatabaseManager (defaults to default_db_manager) - batch_size: Number of patients per Snowflake query batch - - Returns: - DataFrame with columns: UPID, Indication_Group, Source - - Indication_Group: Search_Term (if matched) or "Directory (no GP dx)" (if not) - - Source: "DIAGNOSIS" or "FALLBACK" - """ - import pandas as pd - - if db_manager is None: - db_manager = default_db_manager - - logger.info(f"Starting batch indication lookup for {len(df)} records...") - - # Step 1: Get unique (UPID, Drug Name, PseudoNHSNoLinked, Directory) combinations - # We need PseudoNHSNoLinked to query Snowflake - this matches PatientPseudonym in GP records - # Note: PersonKey is LocalPatientID which is provider-specific and does NOT match GP records - if 'PseudoNHSNoLinked' not in df.columns: - logger.error("DataFrame missing 'PseudoNHSNoLinked' column - cannot lookup GP records") - # Return fallback for all patients - result_df = df[['UPID', 'Directory']].drop_duplicates().copy() - result_df['Indication_Group'] = result_df['Directory'] + " (no GP dx)" - result_df['Source'] = "FALLBACK" - return result_df[['UPID', 'Indication_Group', 'Source']] - - # Get unique patient-drug combinations (we need one lookup per patient-drug pair) - unique_pairs = df[['UPID', 'Drug Name', 'PseudoNHSNoLinked', 'Directory']].drop_duplicates() - logger.info(f"Found {len(unique_pairs)} unique patient-drug combinations") - - # Step 2: Get all unique drugs and their SNOMED codes - unique_drugs = unique_pairs['Drug Name'].unique() - logger.info(f"Building SNOMED lookup for {len(unique_drugs)} unique drugs...") - - # Build drug -> list of DrugSnomedMapping dict - drug_snomed_map: dict[str, list[DrugSnomedMapping]] = {} - all_snomed_codes: set[str] = set() - snomed_to_drug_searchterm: dict[str, list[tuple[str, str, str]]] = {} # snomed -> [(drug, search_term, primary_dir), ...] - - for drug_name in unique_drugs: - mappings = get_drug_snomed_codes(drug_name, db_manager) - drug_snomed_map[drug_name] = mappings - - for m in mappings: - all_snomed_codes.add(m.snomed_code) - if m.snomed_code not in snomed_to_drug_searchterm: - snomed_to_drug_searchterm[m.snomed_code] = [] - snomed_to_drug_searchterm[m.snomed_code].append( - (drug_name, m.search_term, m.primary_directorate) - ) - - logger.info(f"Total SNOMED codes to check: {len(all_snomed_codes)}") - - # Step 3: Check Snowflake availability - if not SNOWFLAKE_AVAILABLE or not is_snowflake_configured(): - logger.warning("Snowflake not available - returning fallback for all patients") - result_df = unique_pairs[['UPID', 'Directory']].copy() - result_df['Indication_Group'] = result_df['Directory'] + " (no GP dx)" - result_df['Source'] = "FALLBACK" - return result_df[['UPID', 'Indication_Group', 'Source']].drop_duplicates(subset=['UPID']) - - if connector is None: - connector = get_connector() - - # Step 4: Query GP records for all patients in batches - # The query finds the most recent matching SNOMED code for each patient - - # Get unique PseudoNHSNoLinked values (each = one patient in GP records) - unique_patients = unique_pairs[['PseudoNHSNoLinked', 'UPID', 'Directory']].drop_duplicates(subset=['PseudoNHSNoLinked']) - patient_pseudonyms = unique_patients['PseudoNHSNoLinked'].tolist() - - logger.info(f"Querying GP records for {len(patient_pseudonyms)} unique patients in batches of {batch_size}...") - - # Results dict: PersonKey -> (snomed_code, event_date) - gp_matches: dict[str, tuple[str, Any]] = {} - - # Convert SNOMED codes to list for query - snomed_list = list(all_snomed_codes) - - if not snomed_list: - logger.warning("No SNOMED codes to check - returning fallback for all patients") - result_df = unique_pairs[['UPID', 'Directory']].copy() - result_df['Indication_Group'] = result_df['Directory'] + " (no GP dx)" - result_df['Source'] = "FALLBACK" - return result_df[['UPID', 'Indication_Group', 'Source']].drop_duplicates(subset=['UPID']) - - # Build SNOMED IN clause (reused across batches) - snomed_placeholders = ", ".join(["%s"] * len(snomed_list)) - - # Process patients in batches - for batch_start in range(0, len(patient_pseudonyms), batch_size): - batch_end = min(batch_start + batch_size, len(patient_pseudonyms)) - batch_pseudonyms = patient_pseudonyms[batch_start:batch_end] - - logger.info(f"Batch {batch_start//batch_size + 1}: patients {batch_start} to {batch_end}") - - # Build patient IN clause - patient_placeholders = ", ".join(["%s"] * len(batch_pseudonyms)) - - # Query to find all matching SNOMED codes for these patients - # We'll get all matches and pick the most recent per patient in Python - query = f''' - SELECT - "PatientPseudonym", - "SNOMEDCode", - "EventDateTime" - FROM DATA_HUB.PHM."PrimaryCareClinicalCoding" - WHERE "PatientPseudonym" IN ({patient_placeholders}) - AND "SNOMEDCode" IN ({snomed_placeholders}) - ORDER BY "PatientPseudonym", "EventDateTime" DESC - ''' - - params = tuple(batch_pseudonyms) + tuple(snomed_list) - - try: - results = connector.execute_dict(query, params) - - # Process results - pick most recent per patient - for row in results: - person_key = row.get("PatientPseudonym") - snomed_code = row.get("SNOMEDCode") - event_date = row.get("EventDateTime") - - if person_key and snomed_code: - # Keep only if we haven't seen this patient yet (first = most recent due to ORDER BY) - if person_key not in gp_matches: - gp_matches[person_key] = (snomed_code, event_date) - - except Exception as e: - logger.error(f"Error querying GP records for batch: {e}") - # Continue with other batches - - logger.info(f"Found GP matches for {len(gp_matches)} patients") - - # Step 5: Build result DataFrame - # For each unique_pair, determine Indication_Group based on match status - results_list = [] - - # We need to dedupe by UPID - a patient might be on multiple drugs - # Strategy: For each UPID, use the most recent match (if any) - upid_to_match: dict[str, tuple[str, str]] = {} # UPID -> (Indication_Group, Source) - - for _, row in unique_pairs.iterrows(): - upid = row['UPID'] - drug_name = row['Drug Name'] - patient_pseudonym = row['PseudoNHSNoLinked'] - directory = row['Directory'] - - # Check if patient has GP match (using PseudoNHSNoLinked which matches PatientPseudonym in GP) - if patient_pseudonym in gp_matches: - matched_snomed, event_date = gp_matches[patient_pseudonym] - - # Find the search_term for this SNOMED code and drug - # (A SNOMED code might map to multiple drugs with different search_terms) - if matched_snomed in snomed_to_drug_searchterm: - # Look for match with current drug first - search_term = None - for drug, st, pd in snomed_to_drug_searchterm[matched_snomed]: - if drug.upper() == drug_name.upper(): - search_term = st - break - # If no drug-specific match, use any match - if search_term is None: - search_term = snomed_to_drug_searchterm[matched_snomed][0][1] - - # Only update if we don't have a match for this UPID yet - if upid not in upid_to_match: - upid_to_match[upid] = (search_term, "DIAGNOSIS") - else: - # Shouldn't happen but fallback just in case - if upid not in upid_to_match: - upid_to_match[upid] = (directory + " (no GP dx)", "FALLBACK") - else: - # No GP match - use fallback - if upid not in upid_to_match: - upid_to_match[upid] = (directory + " (no GP dx)", "FALLBACK") - - # Build result DataFrame - for upid, (indication_group, source) in upid_to_match.items(): - results_list.append({ - 'UPID': upid, - 'Indication_Group': indication_group, - 'Source': source, - }) - - result_df = pd.DataFrame(results_list) - - # Log statistics - diagnosis_count = len([s for s in result_df['Source'] if s == "DIAGNOSIS"]) - fallback_count = len([s for s in result_df['Source'] if s == "FALLBACK"]) - total = len(result_df) - - logger.info(f"Indication lookup complete:") - logger.info(f" Total unique patients: {total}") - logger.info(f" DIAGNOSIS matches: {diagnosis_count} ({100*diagnosis_count/total:.1f}%)") - logger.info(f" FALLBACK (no GP match): {fallback_count} ({100*fallback_count/total:.1f}%)") - - return result_df - - # === Drug-to-indication mapping from DimSearchTerm.csv === @@ -1713,10 +1194,7 @@ __all__ = [ "ClusterSnomedCodes", "IndicationValidationResult", "DrugIndicationMatchRate", - "DrugSnomedMapping", - "DirectSnomedMatchResult", - "DirectorateAssignment", - # Cluster-based lookup functions (existing) + # Cluster-based lookup functions "get_drug_clusters", "get_drug_cluster_ids", "get_cluster_snomed_codes", @@ -1725,20 +1203,13 @@ __all__ = [ "get_indication_match_rate", "batch_validate_indications", "get_available_clusters", - # Direct SNOMED lookup functions (new) - "get_drug_snomed_codes", - "patient_has_indication_direct", - # Diagnosis-based directorate assignment - "get_directorate_from_diagnosis", - # Batch lookup for indication groups - "batch_lookup_indication_groups", # Drug-indication mapping from DimSearchTerm.csv "SEARCH_TERM_MERGE_MAP", "load_drug_indication_mapping", "get_search_terms_for_drug", # Drug-aware indication assignment "assign_drug_indications", - # Snowflake-direct indication lookup (new approach) + # Snowflake-direct indication lookup "get_patient_indication_groups", "CLUSTER_MAPPING_SQL", ] diff --git a/data_processing/load_snomed_mapping.py b/data_processing/load_snomed_mapping.py deleted file mode 100644 index c10821c..0000000 --- a/data_processing/load_snomed_mapping.py +++ /dev/null @@ -1,401 +0,0 @@ -""" -Load enriched SNOMED mapping data into SQLite database. - -This module loads the drug_snomed_mapping_enriched.csv file into the -ref_drug_snomed_mapping table for direct GP record matching. - -Source file: data/drug_snomed_mapping_enriched.csv (163K rows) -Target table: ref_drug_snomed_mapping - -Usage: - python -m data_processing.load_snomed_mapping - -Columns mapped: - Drug -> drug_name - Indication -> indication - TA_ID -> ta_id - Search_Term -> search_term - SNOMEDCode -> snomed_code (cleaned: removes trailing .0) - SNOMEDDescription -> snomed_description - CleanedDrugName -> cleaned_drug_name - PrimaryDirectorate -> primary_directorate - AllDirectorates -> all_directorates -""" - -from pathlib import Path -from typing import Optional - -from core.logging_config import get_logger -from data_processing.database import DatabaseManager -from data_processing.reference_data import MigrationResult, _read_csv_with_fallback_encoding - -logger = get_logger(__name__) - -DEFAULT_CSV_PATH = Path("./data/drug_snomed_mapping_enriched.csv") - - -def clean_snomed_code(snomed_code: str) -> str: - """ - Clean SNOMED code by removing trailing .0 suffix and handling scientific notation. - - The enriched CSV has SNOMED codes that may be in decimal notation (e.g., "156370009.0") - or scientific notation (e.g., "1.0629311000119108e+16") due to pandas/Excel export. - These need to be converted to clean integer strings. - - Args: - snomed_code: Raw SNOMED code from CSV. - - Returns: - Cleaned SNOMED code as string (e.g., "156370009" or "10629311000119108"). - """ - if not snomed_code: - return "" - - code = snomed_code.strip() - - # Handle scientific notation (e.g., "1.0629311000119108e+16") - if 'e' in code.lower(): - try: - # Convert to float first, then to int, then to string - # Using int() directly on the float preserves precision for SNOMED codes - value = float(code) - # Check if it's a whole number (no decimal part) - if value == int(value): - return str(int(value)) - else: - # Has decimal part - return as cleaned float - return str(value).replace('.0', '') - except (ValueError, OverflowError): - # If conversion fails, return as-is but cleaned - return code - - # Remove trailing .0 if present (for non-scientific notation) - if code.endswith(".0"): - code = code[:-2] - - return code - - -def migrate_drug_snomed_mapping( - db_manager: Optional[DatabaseManager] = None, - csv_path: Optional[Path] = None -) -> MigrationResult: - """ - Migrate drug SNOMED mappings from CSV to SQLite ref_drug_snomed_mapping table. - - Source file format (with header): - Drug,Indication,TA_ID,Search_Term,SNOMEDCode,SNOMEDDescription, - CleanedDrugName,PrimaryDirectorate,AllDirectorates - - Example rows: - ABATACEPT,Psoriatic arthritis after DMARDs,TA568,psoriatic arthritis, - 156370009.0,Psoriatic arthritis,ABATACEPT,RHEUMATOLOGY,RHEUMATOLOGY|DERMATOLOGY - - Args: - db_manager: DatabaseManager instance. Uses default if not provided. - csv_path: Path to the CSV file. Defaults to data/drug_snomed_mapping_enriched.csv. - - Returns: - MigrationResult with statistics about the migration. - """ - if db_manager is None: - db_manager = DatabaseManager() - if csv_path is None: - csv_path = DEFAULT_CSV_PATH - - table_name = "ref_drug_snomed_mapping" - - logger.info(f"Migrating drug SNOMED mappings from {csv_path} to {table_name}") - - if not csv_path.exists(): - error_msg = f"Source file not found: {csv_path}" - logger.error(error_msg) - return MigrationResult( - table_name=table_name, - source_file=str(csv_path), - rows_read=0, - rows_inserted=0, - rows_skipped=0, - success=False, - error_message=error_msg - ) - - rows_read = 0 - rows_inserted = 0 - rows_skipped = 0 - - try: - with db_manager.get_transaction() as conn: - rows = _read_csv_with_fallback_encoding(csv_path) - - for i, row in enumerate(rows): - # Skip header row - if i == 0 and len(row) >= 5 and row[0].strip().lower() == "drug": - logger.debug("Skipping header row") - continue - - rows_read += 1 - - # Validate row format (need at least: Drug, Indication, TA_ID, Search_Term, SNOMEDCode) - if len(row) < 5: - logger.warning(f"Skipping malformed row {rows_read}: {row}") - rows_skipped += 1 - continue - - drug_name = row[0].strip() - indication = row[1].strip() - ta_id = row[2].strip() if len(row) > 2 else "" - search_term = row[3].strip() - snomed_code_raw = row[4].strip() if len(row) > 4 else "" - snomed_description = row[5].strip() if len(row) > 5 else "" - cleaned_drug_name = row[6].strip() if len(row) > 6 else drug_name.upper() - primary_directorate = row[7].strip() if len(row) > 7 else "" - all_directorates = row[8].strip() if len(row) > 8 else "" - - # Skip if required fields are empty - if not drug_name or not indication or not search_term or not snomed_code_raw: - logger.warning(f"Skipping row {rows_read} with empty required fields") - rows_skipped += 1 - continue - - # Clean SNOMED code (remove trailing .0) - snomed_code = clean_snomed_code(snomed_code_raw) - - if not snomed_code: - logger.warning(f"Skipping row {rows_read} with invalid SNOMED code: {snomed_code_raw}") - rows_skipped += 1 - continue - - cursor = conn.execute( - """ - INSERT OR IGNORE INTO ref_drug_snomed_mapping - (drug_name, indication, ta_id, search_term, snomed_code, snomed_description, - cleaned_drug_name, primary_directorate, all_directorates) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - drug_name, - indication, - ta_id, - search_term, - snomed_code, - snomed_description, - cleaned_drug_name, - primary_directorate, - all_directorates, - ) - ) - - if cursor.rowcount > 0: - rows_inserted += 1 - else: - rows_skipped += 1 - - # Log progress every 10000 rows - if rows_read % 10000 == 0: - logger.info(f"Processed {rows_read} rows, inserted {rows_inserted}") - - logger.info( - f"Drug SNOMED mapping migration complete: {rows_read} rows read, " - f"{rows_inserted} inserted, {rows_skipped} skipped" - ) - - return MigrationResult( - table_name=table_name, - source_file=str(csv_path), - rows_read=rows_read, - rows_inserted=rows_inserted, - rows_skipped=rows_skipped, - success=True - ) - - except Exception as e: - error_msg = f"Migration failed: {e}" - logger.error(error_msg) - return MigrationResult( - table_name=table_name, - source_file=str(csv_path), - rows_read=rows_read, - rows_inserted=0, - rows_skipped=0, - success=False, - error_message=error_msg - ) - - -def get_drug_snomed_mapping_counts(db_manager: Optional[DatabaseManager] = None) -> dict: - """ - Get statistics about the ref_drug_snomed_mapping table. - - Args: - db_manager: DatabaseManager instance. Uses default if not provided. - - Returns: - Dictionary with: - - total_mappings: Total rows in table - - unique_drugs: Count of distinct drug names - - unique_search_terms: Count of distinct search terms - - unique_snomed_codes: Count of distinct SNOMED codes - - unique_indications: Count of distinct indications - """ - if db_manager is None: - db_manager = DatabaseManager() - - with db_manager.get_connection() as conn: - cursor = conn.execute("SELECT COUNT(*) FROM ref_drug_snomed_mapping") - total = cursor.fetchone()[0] - - cursor = conn.execute("SELECT COUNT(DISTINCT drug_name) FROM ref_drug_snomed_mapping") - unique_drugs = cursor.fetchone()[0] - - cursor = conn.execute("SELECT COUNT(DISTINCT search_term) FROM ref_drug_snomed_mapping") - unique_search_terms = cursor.fetchone()[0] - - cursor = conn.execute("SELECT COUNT(DISTINCT snomed_code) FROM ref_drug_snomed_mapping") - unique_snomed_codes = cursor.fetchone()[0] - - cursor = conn.execute("SELECT COUNT(DISTINCT indication) FROM ref_drug_snomed_mapping") - unique_indications = cursor.fetchone()[0] - - return { - "total_mappings": total, - "unique_drugs": unique_drugs, - "unique_search_terms": unique_search_terms, - "unique_snomed_codes": unique_snomed_codes, - "unique_indications": unique_indications, - } - - -def verify_drug_snomed_mapping_migration( - db_manager: Optional[DatabaseManager] = None, - csv_path: Optional[Path] = None -) -> tuple[bool, str]: - """ - Verify that drug SNOMED mappings were migrated correctly. - - Checks: - - Row count is reasonable (163K+ expected) - - Unique search terms is reasonable (187 expected) - - Sample lookups return expected values - - Args: - db_manager: DatabaseManager instance. Uses default if not provided. - csv_path: Path to the CSV file. Defaults to data/drug_snomed_mapping_enriched.csv. - - Returns: - Tuple of (success: bool, message: str) - """ - if db_manager is None: - db_manager = DatabaseManager() - if csv_path is None: - csv_path = DEFAULT_CSV_PATH - - stats = get_drug_snomed_mapping_counts(db_manager) - - # Basic sanity checks - if stats["total_mappings"] < 100000: - return False, f"Too few rows: expected 163K+, got {stats['total_mappings']}" - - if stats["unique_search_terms"] < 100: - return False, f"Too few search terms: expected ~187, got {stats['unique_search_terms']}" - - # Sample lookup verification - with db_manager.get_connection() as conn: - # Check that ABATACEPT exists (from sample data) - cursor = conn.execute( - "SELECT COUNT(*) FROM ref_drug_snomed_mapping WHERE drug_name = 'ABATACEPT'" - ) - abatacept_count = cursor.fetchone()[0] - if abatacept_count == 0: - return False, "Sample drug ABATACEPT not found in table" - - # Check that SNOMED codes were cleaned (no .0 suffix) - cursor = conn.execute( - "SELECT COUNT(*) FROM ref_drug_snomed_mapping WHERE snomed_code LIKE '%.0'" - ) - dirty_codes = cursor.fetchone()[0] - if dirty_codes > 0: - return False, f"Found {dirty_codes} SNOMED codes with uncleaned .0 suffix" - - return True, ( - f"Verified {stats['total_mappings']:,} mappings: " - f"{stats['unique_drugs']} drugs, " - f"{stats['unique_search_terms']} search terms, " - f"{stats['unique_snomed_codes']:,} SNOMED codes" - ) - - -def main(): - """CLI entry point for loading SNOMED mapping data.""" - import argparse - - parser = argparse.ArgumentParser( - description="Load drug SNOMED mapping data into SQLite database" - ) - parser.add_argument( - "--csv", - type=Path, - default=DEFAULT_CSV_PATH, - help=f"Path to CSV file (default: {DEFAULT_CSV_PATH})" - ) - parser.add_argument( - "--verify-only", - action="store_true", - help="Only verify existing data, don't migrate" - ) - parser.add_argument( - "-v", "--verbose", - action="store_true", - help="Enable verbose logging" - ) - - args = parser.parse_args() - - # Configure logging - import logging - if args.verbose: - logging.basicConfig(level=logging.DEBUG) - else: - logging.basicConfig(level=logging.INFO) - - if args.verify_only: - print("Verifying existing data...") - success, message = verify_drug_snomed_mapping_migration(csv_path=args.csv) - if success: - print(f"[OK] Verification passed: {message}") - else: - print(f"[FAILED] Verification failed: {message}") - return 0 if success else 1 - - # Run migration - print(f"Loading SNOMED mapping from {args.csv}...") - result = migrate_drug_snomed_mapping(csv_path=args.csv) - - if result.success: - print(f"[OK] {result}") - - # Show statistics - stats = get_drug_snomed_mapping_counts() - print(f"\nTable statistics:") - print(f" Total mappings: {stats['total_mappings']:,}") - print(f" Unique drugs: {stats['unique_drugs']}") - print(f" Unique search terms: {stats['unique_search_terms']}") - print(f" Unique SNOMED codes: {stats['unique_snomed_codes']:,}") - print(f" Unique indications: {stats['unique_indications']}") - - # Verify - success, message = verify_drug_snomed_mapping_migration(csv_path=args.csv) - if success: - print(f"\n[OK] Verification: {message}") - else: - print(f"\n[WARNING] Verification: {message}") - return 1 - else: - print(f"[FAILED] {result}") - return 1 - - return 0 - - -if __name__ == "__main__": - exit(main()) diff --git a/data_processing/loader.py b/data_processing/loader.py index 7263447..ec19b1d 100644 --- a/data_processing/loader.py +++ b/data_processing/loader.py @@ -11,7 +11,6 @@ The DataLoader ABC defines the contract for all loader implementations. from abc import ABC, abstractmethod from dataclasses import dataclass, field -from datetime import date from pathlib import Path from typing import Optional @@ -29,7 +28,7 @@ class LoadResult: Attributes: df: The loaded DataFrame with processed patient intervention data - source: Description of the data source (e.g., "csv:/path/to/file.csv", "sqlite:fact_interventions") + source: Description of the data source (e.g., "file:/path/to/file.csv") row_count: Number of rows loaded columns: List of column names in the DataFrame load_time_seconds: Time taken to load the data @@ -224,150 +223,6 @@ class FileDataLoader(DataLoader): ) -class SQLiteDataLoader(DataLoader): - """Loads data from SQLite fact_interventions table. - - This provides faster loading by reading pre-processed data from SQLite - instead of re-processing CSV files each time. - - The SQLite database must have been populated by the migration scripts. - - Args: - db_path: Path to the SQLite database (uses default if None) - date_range: Optional tuple of (start_date, end_date) to filter data - trusts: Optional list of trust names to filter - drugs: Optional list of drug names to filter - directories: Optional list of directories to filter - """ - - def __init__( - self, - db_path: Optional[Path | str] = None, - date_range: Optional[tuple[date, date]] = None, - trusts: Optional[list[str]] = None, - drugs: Optional[list[str]] = None, - directories: Optional[list[str]] = None, - ): - from data_processing.database import default_db_config - - self.db_path = Path(db_path) if db_path else Path(default_db_config.db_path) - self.date_range = date_range - self.trusts = trusts - self.drugs = drugs - self.directories = directories - - def validate_source(self) -> tuple[bool, str]: - """Check if the database exists and has the fact_interventions table.""" - if not self.db_path.exists(): - return False, f"Database not found: {self.db_path}" - - # Check if fact_interventions table exists - from data_processing.database import DatabaseManager, DatabaseConfig - - config = DatabaseConfig(db_path=self.db_path) - manager = DatabaseManager(config) - - if not manager.table_exists("fact_interventions"): - return False, "fact_interventions table not found in database" - - count = manager.get_table_count("fact_interventions") - if count == 0: - return False, "fact_interventions table is empty" - - return True, f"OK ({count:,} rows available)" - - @property - def source_description(self) -> str: - return f"sqlite:{self.db_path}" - - def load(self) -> LoadResult: - """Load data from SQLite fact_interventions table. - - Maps SQLite column names to the expected DataFrame column names. - Applies optional filters for date range, trusts, drugs, directories. - """ - import time - from data_processing.database import DatabaseManager, DatabaseConfig - - start_time = time.time() - - # Validate source - is_valid, msg = self.validate_source() - if not is_valid: - raise FileNotFoundError(msg) - - logger.info(f"Loading data from SQLite: {self.db_path}") - - # Build query with optional filters - query = """ - SELECT - upid AS "UPID", - provider_code AS "Provider Code", - person_key AS "PersonKey", - drug_name_std AS "Drug Name", - intervention_date AS "Intervention Date", - price_actual AS "Price Actual", - org_name AS "OrganisationName", - directory AS "Directory", - treatment_function_code AS "Treatment Function Code", - additional_detail_1 AS "Additional Detail 1", - additional_detail_2 AS "Additional Detail 2", - additional_detail_3 AS "Additional Detail 3", - additional_detail_4 AS "Additional Detail 4", - additional_detail_5 AS "Additional Detail 5" - FROM fact_interventions - WHERE 1=1 - """ - params = [] - - if self.date_range: - start, end = self.date_range - query += " AND intervention_date >= ? AND intervention_date < ?" - params.extend([str(start), str(end)]) - - if self.trusts: - placeholders = ','.join('?' * len(self.trusts)) - query += f" AND org_name IN ({placeholders})" - params.extend(self.trusts) - - if self.drugs: - placeholders = ','.join('?' * len(self.drugs)) - query += f" AND drug_name_std IN ({placeholders})" - params.extend(self.drugs) - - if self.directories: - placeholders = ','.join('?' * len(self.directories)) - query += f" AND directory IN ({placeholders})" - params.extend(self.directories) - - # Execute query - config = DatabaseConfig(db_path=self.db_path) - manager = DatabaseManager(config) - - with manager.get_connection() as conn: - df = pd.read_sql_query(query, conn, params=params) - - # Convert intervention_date to datetime - df['Intervention Date'] = pd.to_datetime(df['Intervention Date']) - - logger.info(f"Loaded {len(df)} rows from SQLite") - - # Validate result - is_valid, missing = self.validate_dataframe(df) - if not is_valid: - raise ValueError(f"SQLite data missing required columns: {missing}") - - load_time = time.time() - start_time - logger.info(f"SQLite data loading complete. {len(df)} rows in {load_time:.2f}s") - - return LoadResult( - df=df, - source=self.source_description, - row_count=len(df), - load_time_seconds=load_time, - ) - - def get_loader( source: str | Path, paths: Optional[PathConfig] = None, @@ -376,7 +231,7 @@ def get_loader( """Factory function to create the appropriate DataLoader. Args: - source: Either a file path (CSV/Parquet) or "sqlite" for database + source: File path (CSV/Parquet) paths: PathConfig for reference data (used by FileDataLoader) **kwargs: Additional arguments passed to the loader constructor @@ -386,14 +241,6 @@ def get_loader( Examples: >>> loader = get_loader("data/activity.csv") >>> loader = get_loader("data/activity.parquet") - >>> loader = get_loader("sqlite") - >>> loader = get_loader("sqlite", date_range=(date(2024, 1, 1), date(2024, 12, 31))) """ - source_str = str(source).lower() - - if source_str == "sqlite": - return SQLiteDataLoader(**kwargs) - - # Assume it's a file path path = Path(source) return FileDataLoader(file_path=path, paths=paths) diff --git a/data_processing/migrate.py b/data_processing/migrate.py index 45cf908..ff87489 100644 --- a/data_processing/migrate.py +++ b/data_processing/migrate.py @@ -35,6 +35,7 @@ from data_processing.schema import ( verify_all_tables_exist, get_all_table_counts, migrate_pathway_nodes_chart_type, + migrate_refresh_log_source_row_count, ) from data_processing.reference_data import ( MigrationResult, @@ -49,12 +50,6 @@ from data_processing.reference_data import ( verify_drug_directory_map_migration, verify_drug_indication_clusters_migration, ) -from data_processing.patient_data import ( - load_patient_data, - refresh_patient_treatment_summary, - get_patient_data_stats, - verify_mv_consistency, -) logger = get_logger(__name__) @@ -67,9 +62,8 @@ def initialize_database( """ Initialize the database with all required tables. - Creates all tables defined in the schema (reference tables, fact tables, - materialized views, and file tracking tables). Uses IF NOT EXISTS so - safe to run multiple times. + Creates all tables defined in the schema (reference tables and pathway + tables). Uses IF NOT EXISTS so safe to run multiple times. Args: db_manager: DatabaseManager instance. Uses default if not provided. @@ -122,6 +116,14 @@ def initialize_database( else: logger.error(f"pathway_nodes migration failed: {msg}") return False + + # Add source_row_count column to pathway_refresh_log if it doesn't exist + success, msg = migrate_refresh_log_source_row_count(conn) + if success: + logger.info(f"pathway_refresh_log migration: {msg}") + else: + logger.error(f"pathway_refresh_log migration failed: {msg}") + return False except Exception as e: logger.error(f"Migration failed: {e}") return False @@ -274,107 +276,6 @@ def create_progress_reporter(description: str = "Loading", width: int = 40): return report_progress -def load_patient_data_cli( - file_path: Path, - db_manager: Optional[DatabaseManager] = None, - paths: Optional[PathConfig] = None, - force: bool = False, - refresh_mv: bool = True -) -> bool: - """ - Load patient data from file with CLI progress reporting. - - Args: - file_path: Path to CSV or Parquet file. - db_manager: DatabaseManager instance. Uses default if not provided. - paths: PathConfig for reference data. Uses default if not provided. - force: If True, re-process even if file hash matches. - refresh_mv: If True, refresh the materialized view after loading. - - Returns: - True if loading succeeded, False otherwise. - """ - if db_manager is None: - db_manager = DatabaseManager() - if paths is None: - paths = default_paths - - print(f"\n=== Loading Patient Data ===\n") - print(f"File: {file_path}") - - # Check file exists - if not file_path.exists(): - print(f"ERROR: File not found: {file_path}") - return False - - # Calculate and display file info - file_size_mb = file_path.stat().st_size / (1024 * 1024) - print(f"Size: {file_size_mb:.1f} MB") - print() - - # Create progress callback - progress_callback = create_progress_reporter("Loading rows", width=40) - - # Load the data - result = load_patient_data( - file_path=file_path, - db_manager=db_manager, - paths=paths, - batch_size=5000, - force=force, - progress_callback=progress_callback - ) - - # Print result - print() - if result.was_already_processed: - print("File already processed (same hash). Skipping.") - print(f"Use --force to re-process.") - elif result.success: - print(f"Loaded {result.rows_inserted:,} rows in {result.load_time_seconds:.1f}s") - if result.rows_skipped > 0: - print(f"Skipped {result.rows_skipped:,} rows (missing UPID or date)") - else: - print(f"FAILED: {result.error_message}") - return False - - # Refresh materialized view if requested - if refresh_mv and result.success and not result.was_already_processed: - print() - print("Refreshing materialized view...") - mv_progress = create_progress_reporter("Processing patients", width=40) - mv_result = refresh_patient_treatment_summary( - db_manager=db_manager, - progress_callback=mv_progress - ) - - if mv_result.success: - print(f"MV refreshed: {mv_result.patients_processed:,} patients in {mv_result.refresh_time_seconds:.1f}s") - - # Verify consistency - consistent, msg = verify_mv_consistency(db_manager) - if consistent: - print(f"MV verification: OK") - else: - print(f"MV verification: FAILED - {msg}") - else: - print(f"MV refresh FAILED: {mv_result.error_message}") - - # Print summary statistics - print() - print("=== Patient Data Summary ===") - stats = get_patient_data_stats(db_manager) - print(f" Total rows: {stats['total_rows']:,}") - print(f" Unique patients: {stats['unique_patients']:,}") - print(f" Unique drugs: {stats['unique_drugs']:,}") - print(f" Unique organizations: {stats['unique_organizations']:,}") - if stats['date_range'][0] and stats['date_range'][1]: - print(f" Date range: {stats['date_range'][0]} to {stats['date_range'][1]}") - print() - - return result.success - - def get_database_status(db_manager: Optional[DatabaseManager] = None) -> dict: """ Get the current status of the database. @@ -452,8 +353,6 @@ Examples: python -m data_processing.migrate --drop-existing # Reset database python -m data_processing.migrate --reference-data # Migrate reference data python -m data_processing.migrate --reference-data --verify # With verification - python -m data_processing.migrate --load-patient-data data.parquet # Load patient data - python -m data_processing.migrate --load-patient-data data.csv --force # Force reload python -m data_processing.migrate --db-path ./data/test.db # Custom path """ ) @@ -493,23 +392,6 @@ Examples: action="store_true", help="Enable verbose logging" ) - parser.add_argument( - "--load-patient-data", - type=Path, - metavar="FILE", - help="Load patient data from CSV or Parquet file with progress reporting" - ) - parser.add_argument( - "--force", - action="store_true", - help="Force re-processing even if file hash matches (use with --load-patient-data)" - ) - parser.add_argument( - "--no-refresh-mv", - action="store_true", - help="Skip materialized view refresh after loading (use with --load-patient-data)" - ) - args = parser.parse_args() # Set up logging @@ -562,32 +444,6 @@ Examples: print("Reference data migration completed with errors. Check logs for details.") return 1 - # Handle --load-patient-data (load patient data from CSV/Parquet) - if args.load_patient_data: - # Ensure database exists with tables first - if not db_manager.exists: - print("Database does not exist. Initializing schema first...") - success = initialize_database(db_manager=db_manager) - if not success: - print("\nDatabase initialization failed. Check logs for details.") - return 1 - - # Load patient data with progress reporting - success = load_patient_data_cli( - file_path=args.load_patient_data, - db_manager=db_manager, - paths=default_paths, - force=args.force, - refresh_mv=not args.no_refresh_mv - ) - - if success: - print("Patient data load completed successfully.") - return 0 - else: - print("Patient data load failed. Check logs for details.") - return 1 - # Run schema migration (default behavior) success = initialize_database( db_manager=db_manager, diff --git a/data_processing/patient_data.py b/data_processing/patient_data.py deleted file mode 100644 index 64b1b02..0000000 --- a/data_processing/patient_data.py +++ /dev/null @@ -1,890 +0,0 @@ -""" -Patient data migration functions for NHS High-Cost Drug Patient Pathway Analysis Tool. - -Provides functions to load patient intervention data from CSV/Parquet files -into the SQLite fact_interventions table. Supports: -- Batch processing for large files -- File hash tracking for incremental updates -- Progress reporting during loading -""" - -import hashlib -import os -import sqlite3 -import time -from dataclasses import dataclass -from datetime import datetime -from pathlib import Path -from typing import Callable, Optional - -import pandas as pd - -from core import PathConfig, default_paths -from core.logging_config import get_logger -from data_processing.database import DatabaseManager - -logger = get_logger(__name__) - - -@dataclass -class PatientDataLoadResult: - """Results from a patient data load operation.""" - file_path: str - file_hash: str - rows_read: int - rows_inserted: int - rows_skipped: int - success: bool - error_message: Optional[str] = None - load_time_seconds: float = 0.0 - was_already_processed: bool = False - - def __str__(self) -> str: - if self.was_already_processed: - return f"{self.file_path}: Already processed (same hash)" - elif self.success: - return ( - f"{self.file_path}: Loaded {self.rows_inserted:,} rows " - f"in {self.load_time_seconds:.1f}s" - ) - else: - return f"{self.file_path}: FAILED - {self.error_message}" - - -def calculate_file_hash(file_path: Path) -> str: - """ - Calculate SHA256 hash of a file. - - Uses chunked reading to handle large files efficiently. - - Args: - file_path: Path to the file. - - Returns: - Hex string of SHA256 hash. - """ - sha256_hash = hashlib.sha256() - with open(file_path, "rb") as f: - for chunk in iter(lambda: f.read(8192), b""): - sha256_hash.update(chunk) - return sha256_hash.hexdigest() - - -def check_file_processed( - conn: sqlite3.Connection, - file_path: str, - file_hash: str -) -> tuple[bool, Optional[str]]: - """ - Check if a file has already been processed with the same hash. - - Args: - conn: Database connection. - file_path: Full path to the file. - file_hash: SHA256 hash of the file. - - Returns: - Tuple of (is_processed, old_hash). - - If is_processed is True and old_hash == file_hash, file is unchanged. - - If is_processed is True and old_hash != file_hash, file has changed. - - If is_processed is False, file is new. - """ - cursor = conn.execute( - "SELECT file_hash, status FROM processed_files WHERE file_path = ?", - (file_path,) - ) - result = cursor.fetchone() - - if result is None: - return False, None - - old_hash = result["file_hash"] - status = result["status"] - - # Only consider it processed if status is success and hash matches - if status == "success" and old_hash == file_hash: - return True, old_hash - - return False, old_hash - - -def record_file_processing_start( - conn: sqlite3.Connection, - file_path: str, - file_hash: str, - file_size: int, - file_modified: datetime -) -> None: - """ - Record that we're starting to process a file. - - Args: - conn: Database connection. - file_path: Full path to the file. - file_hash: SHA256 hash of the file. - file_size: File size in bytes. - file_modified: File modification timestamp. - """ - file_name = Path(file_path).name - now = datetime.now().isoformat() - - conn.execute(""" - INSERT INTO processed_files ( - file_path, file_name, file_hash, file_size_bytes, - file_modified_at, status, first_processed_at, last_processed_at - ) VALUES (?, ?, ?, ?, ?, 'processing', ?, ?) - ON CONFLICT(file_path) DO UPDATE SET - file_hash = excluded.file_hash, - file_size_bytes = excluded.file_size_bytes, - file_modified_at = excluded.file_modified_at, - status = 'processing', - last_processed_at = excluded.last_processed_at, - error_message = NULL - """, (file_path, file_name, file_hash, file_size, file_modified.isoformat(), now, now)) - - -def record_file_processing_complete( - conn: sqlite3.Connection, - file_path: str, - row_count: int, - duration_seconds: float, - success: bool, - error_message: Optional[str] = None -) -> None: - """ - Record that file processing has completed. - - Args: - conn: Database connection. - file_path: Full path to the file. - row_count: Number of rows processed. - duration_seconds: Time taken to process. - success: Whether processing was successful. - error_message: Error message if failed. - """ - status = "success" if success else "error" - - conn.execute(""" - UPDATE processed_files - SET status = ?, - row_count = ?, - processing_duration_seconds = ?, - error_message = ?, - last_processed_at = ? - WHERE file_path = ? - """, (status, row_count, duration_seconds, error_message, datetime.now().isoformat(), file_path)) - - -def load_dataframe_to_sqlite( - df: pd.DataFrame, - conn: sqlite3.Connection, - source_file: str, - batch_size: int = 5000, - progress_callback: Optional[Callable[[int, int], None]] = None -) -> int: - """ - Load a processed DataFrame into fact_interventions table. - - Args: - df: Processed DataFrame with required columns (from FileDataLoader). - conn: Database connection. - source_file: Source file path for tracking. - batch_size: Number of rows to insert per batch. - progress_callback: Optional callback(rows_inserted, total_rows) for progress updates. - - Returns: - Number of rows inserted. - """ - # Store the original drug names before processing (for rows where mapping doesn't exist) - # The drug_names() transformation sets Drug Name to NULL when no mapping exists. - # We need to preserve the original for those cases. - - # Insert SQL columns - always include drug_name_raw - insert_columns = [ - "upid", "provider_code", "person_key", - "drug_name_raw", "drug_name_std", - "intervention_date", "price_actual", - "org_name", "directory", - "treatment_function_code", - "additional_detail_1", "additional_detail_2", "additional_detail_3", - "additional_detail_4", "additional_detail_5", - "source_file" - ] - placeholders = ",".join(["?"] * len(insert_columns)) - insert_sql = f""" - INSERT INTO fact_interventions ({",".join(insert_columns)}) - VALUES ({placeholders}) - """ - - rows_inserted = 0 - rows_skipped = 0 - total_rows = len(df) - - # Process in batches - for batch_start in range(0, total_rows, batch_size): - batch_end = min(batch_start + batch_size, total_rows) - batch_df = df.iloc[batch_start:batch_end] - - # Prepare batch data - batch_data = [] - for _, row in batch_df.iterrows(): - # Skip rows missing required fields - if pd.isna(row.get("UPID")) or pd.isna(row.get("Intervention Date")): - rows_skipped += 1 - continue - # Get drug names - raw and standardized - drug_name_raw = row.get("Drug Name Raw") if "Drug Name Raw" in df.columns else None - drug_name_std = row.get("Drug Name") - - # If drug_name_std is NULL, use the raw drug name (uppercase) - # This handles cases where the drug isn't in the drugnames.csv mapping - if pd.isna(drug_name_std): - if drug_name_raw is not None and not pd.isna(drug_name_raw): - drug_name_std = str(drug_name_raw).upper().strip() - else: - drug_name_std = "UNKNOWN" - - # Also clean up raw drug name for storage - if drug_name_raw is not None and not pd.isna(drug_name_raw): - drug_name_raw = str(drug_name_raw).strip() - - # Get other values with null handling - def get_value(col_name): - if col_name not in df.columns: - return None - val = row[col_name] - if pd.isna(val): - return None - elif hasattr(val, "strftime"): - return val.strftime("%Y-%m-%d") - return val - - row_data = ( - get_value("UPID"), - get_value("Provider Code"), - get_value("PersonKey"), - drug_name_raw, - drug_name_std, - get_value("Intervention Date"), - get_value("Price Actual") or 0, - get_value("OrganisationName"), - get_value("Directory"), - get_value("Treatment Function Code"), - get_value("Additional Detail 1"), - get_value("Additional Detail 2"), - get_value("Additional Detail 3"), - get_value("Additional Detail 4"), - get_value("Additional Detail 5"), - source_file - ) - batch_data.append(row_data) - - # Execute batch insert - conn.executemany(insert_sql, batch_data) - rows_inserted += len(batch_data) - - # Report progress - if progress_callback: - progress_callback(rows_inserted, total_rows) - - if rows_skipped > 0: - logger.info(f"Skipped {rows_skipped:,} rows with missing UPID or Intervention Date") - - return rows_inserted - - -def delete_file_data(conn: sqlite3.Connection, source_file: str) -> int: - """ - Delete all data from a specific source file. - - Used when re-processing a changed file. - - Args: - conn: Database connection. - source_file: Source file path. - - Returns: - Number of rows deleted. - """ - cursor = conn.execute( - "DELETE FROM fact_interventions WHERE source_file = ?", - (source_file,) - ) - return cursor.rowcount - - -def load_patient_data( - file_path: Path | str, - db_manager: Optional[DatabaseManager] = None, - paths: Optional[PathConfig] = None, - batch_size: int = 5000, - force: bool = False, - progress_callback: Optional[Callable[[int, int], None]] = None -) -> PatientDataLoadResult: - """ - Load patient data from CSV/Parquet file into fact_interventions table. - - This is the main entry point for loading patient data. It: - 1. Calculates file hash to detect changes - 2. Checks if file was already processed (skip if unchanged) - 3. Loads and transforms data using FileDataLoader - 4. Inserts data into SQLite in batches - 5. Records processing status in processed_files table - - Args: - file_path: Path to CSV or Parquet file. - db_manager: DatabaseManager instance. Uses default if not provided. - paths: PathConfig for reference data. Uses default if not provided. - batch_size: Number of rows to insert per batch (default: 5000). - force: If True, re-process even if file hash matches. - progress_callback: Optional callback(rows_inserted, total_rows) for progress. - - Returns: - PatientDataLoadResult with loading statistics. - """ - if db_manager is None: - db_manager = DatabaseManager() - if paths is None: - paths = default_paths - - file_path = Path(file_path) - file_path_str = str(file_path.absolute()) - - logger.info(f"Starting patient data load from {file_path}") - start_time = time.time() - - # Check file exists - if not file_path.exists(): - error_msg = f"File not found: {file_path}" - logger.error(error_msg) - return PatientDataLoadResult( - file_path=file_path_str, - file_hash="", - rows_read=0, - rows_inserted=0, - rows_skipped=0, - success=False, - error_message=error_msg - ) - - # Calculate file hash - logger.info("Calculating file hash...") - file_hash = calculate_file_hash(file_path) - file_size = file_path.stat().st_size - file_modified = datetime.fromtimestamp(file_path.stat().st_mtime) - - logger.info(f"File hash: {file_hash[:16]}... Size: {file_size:,} bytes") - - # Check if already processed - if not force: - with db_manager.get_connection() as conn: - is_processed, old_hash = check_file_processed(conn, file_path_str, file_hash) - if is_processed: - logger.info(f"File already processed with same hash, skipping") - return PatientDataLoadResult( - file_path=file_path_str, - file_hash=file_hash, - rows_read=0, - rows_inserted=0, - rows_skipped=0, - success=True, - was_already_processed=True - ) - elif old_hash is not None: - logger.info(f"File hash changed, will re-process (old: {old_hash[:16]}...)") - - try: - # Use FileDataLoader to load and transform data - from data_processing.loader import FileDataLoader - - loader = FileDataLoader(file_path, paths) - logger.info("Loading and transforming data...") - result = loader.load() - df = result.df - rows_read = result.row_count - - logger.info(f"Loaded {rows_read:,} rows, starting SQLite insert...") - - # Load into SQLite - with db_manager.get_transaction() as conn: - # Record that we're starting - record_file_processing_start(conn, file_path_str, file_hash, file_size, file_modified) - - # Delete any existing data from this file (for re-processing) - deleted = delete_file_data(conn, file_path_str) - if deleted > 0: - logger.info(f"Deleted {deleted:,} existing rows from previous load") - - # Insert new data - rows_inserted = load_dataframe_to_sqlite( - df, conn, file_path_str, batch_size, progress_callback - ) - - # Record success - load_time = time.time() - start_time - record_file_processing_complete( - conn, file_path_str, rows_inserted, load_time, True - ) - - logger.info(f"Successfully loaded {rows_inserted:,} rows in {load_time:.1f}s") - - return PatientDataLoadResult( - file_path=file_path_str, - file_hash=file_hash, - rows_read=rows_read, - rows_inserted=rows_inserted, - rows_skipped=rows_read - rows_inserted, - success=True, - load_time_seconds=load_time - ) - - except Exception as e: - load_time = time.time() - start_time - error_msg = str(e) - logger.error(f"Failed to load patient data: {error_msg}") - - # Record failure - try: - with db_manager.get_connection() as conn: - record_file_processing_complete( - conn, file_path_str, 0, load_time, False, error_msg - ) - except Exception: - pass # Don't fail on failure to record failure - - return PatientDataLoadResult( - file_path=file_path_str, - file_hash=file_hash if 'file_hash' in dir() else "", - rows_read=0, - rows_inserted=0, - rows_skipped=0, - success=False, - error_message=error_msg, - load_time_seconds=load_time - ) - - -def get_patient_data_stats(db_manager: Optional[DatabaseManager] = None) -> dict: - """ - Get statistics about patient data in fact_interventions. - - Returns: - Dictionary with statistics about the loaded data. - """ - if db_manager is None: - db_manager = DatabaseManager() - - stats = {} - - with db_manager.get_connection() as conn: - # Total rows - cursor = conn.execute("SELECT COUNT(*) FROM fact_interventions") - stats["total_rows"] = cursor.fetchone()[0] - - # Unique patients - cursor = conn.execute("SELECT COUNT(DISTINCT upid) FROM fact_interventions") - stats["unique_patients"] = cursor.fetchone()[0] - - # Unique drugs - cursor = conn.execute("SELECT COUNT(DISTINCT drug_name_std) FROM fact_interventions") - stats["unique_drugs"] = cursor.fetchone()[0] - - # Unique organizations - cursor = conn.execute("SELECT COUNT(DISTINCT org_name) FROM fact_interventions") - stats["unique_organizations"] = cursor.fetchone()[0] - - # Date range - cursor = conn.execute(""" - SELECT MIN(intervention_date), MAX(intervention_date) - FROM fact_interventions - """) - result = cursor.fetchone() - stats["date_range"] = (result[0], result[1]) if result else (None, None) - - # Processed files - cursor = conn.execute(""" - SELECT COUNT(*), SUM(row_count) - FROM processed_files WHERE status = 'success' - """) - result = cursor.fetchone() - stats["processed_files"] = result[0] if result else 0 - stats["processed_rows"] = result[1] if result and result[1] else 0 - - return stats - - -def list_processed_files(db_manager: Optional[DatabaseManager] = None) -> list[dict]: - """ - List all processed files and their status. - - Returns: - List of dictionaries with file processing information. - """ - if db_manager is None: - db_manager = DatabaseManager() - - files = [] - - with db_manager.get_connection() as conn: - cursor = conn.execute(""" - SELECT file_path, file_name, file_hash, file_size_bytes, - row_count, status, error_message, - first_processed_at, last_processed_at, processing_duration_seconds - FROM processed_files - ORDER BY last_processed_at DESC - """) - - for row in cursor.fetchall(): - files.append({ - "file_path": row["file_path"], - "file_name": row["file_name"], - "file_hash": row["file_hash"], - "file_size_bytes": row["file_size_bytes"], - "row_count": row["row_count"], - "status": row["status"], - "error_message": row["error_message"], - "first_processed_at": row["first_processed_at"], - "last_processed_at": row["last_processed_at"], - "processing_duration_seconds": row["processing_duration_seconds"], - }) - - return files - - -# ============================================================================= -# Materialized View Refresh Functions -# ============================================================================= - -@dataclass -class MVRefreshResult: - """Results from refreshing the patient treatment summary materialized view.""" - patients_processed: int - rows_inserted: int - refresh_time_seconds: float - success: bool - error_message: Optional[str] = None - - def __str__(self) -> str: - if self.success: - return ( - f"Refreshed MV: {self.patients_processed:,} patients " - f"in {self.refresh_time_seconds:.1f}s" - ) - else: - return f"MV refresh FAILED: {self.error_message}" - - -def refresh_patient_treatment_summary( - db_manager: Optional[DatabaseManager] = None, - progress_callback: Optional[Callable[[int, int], None]] = None -) -> MVRefreshResult: - """ - Refresh the mv_patient_treatment_summary materialized view. - - This computes per-patient aggregations from fact_interventions: - - First/last seen dates - - Total cost, average cost per intervention - - Intervention count, unique drug count - - Drug sequence (chronological, pipe-separated) - - Drug counts, costs, and date ranges (as JSON) - - The MV is fully rebuilt (truncate and re-insert) for simplicity. - This typically takes 30-60 seconds for ~35,000 patients. - - Args: - db_manager: DatabaseManager instance. Uses default if not provided. - progress_callback: Optional callback(patients_done, total_patients). - - Returns: - MVRefreshResult with refresh statistics. - """ - if db_manager is None: - db_manager = DatabaseManager() - - logger.info("Starting materialized view refresh...") - start_time = time.time() - - try: - with db_manager.get_transaction() as conn: - # Step 1: Get total patient count for progress reporting - cursor = conn.execute("SELECT COUNT(DISTINCT upid) FROM fact_interventions") - total_patients = cursor.fetchone()[0] - logger.info(f"Processing {total_patients:,} unique patients") - - if total_patients == 0: - logger.warning("No patient data in fact_interventions, MV will be empty") - return MVRefreshResult( - patients_processed=0, - rows_inserted=0, - refresh_time_seconds=time.time() - start_time, - success=True - ) - - # Step 2: Clear existing MV data - conn.execute("DELETE FROM mv_patient_treatment_summary") - logger.info("Cleared existing MV data") - - # Step 3: Compute aggregations using SQL CTEs - # This is more efficient than processing row-by-row in Python - refresh_sql = """ - WITH patient_aggs AS ( - -- Basic aggregations per patient - SELECT - upid, - MIN(org_name) as org_name, - MIN(directory) as directory, - MIN(intervention_date) as first_seen_date, - MAX(intervention_date) as last_seen_date, - JULIANDAY(MAX(intervention_date)) - JULIANDAY(MIN(intervention_date)) as days_treated, - SUM(price_actual) as total_cost, - AVG(price_actual) as avg_cost_per_intervention, - COUNT(*) as intervention_count, - COUNT(DISTINCT drug_name_std) as unique_drug_count, - COUNT(*) as source_row_count - FROM fact_interventions - GROUP BY upid - ), - drug_sequences AS ( - -- Drug sequence per patient (chronological order, pipe-separated) - SELECT - upid, - GROUP_CONCAT(drug_name_std, '|') as drug_sequence - FROM ( - SELECT DISTINCT - upid, - drug_name_std, - MIN(intervention_date) as first_date - FROM fact_interventions - GROUP BY upid, drug_name_std - ORDER BY upid, first_date - ) - GROUP BY upid - ), - drug_counts AS ( - -- JSON object of drug counts per patient - SELECT - upid, - '{' || GROUP_CONCAT('"' || drug_name_std || '": ' || cnt, ', ') || '}' as drug_counts_json - FROM ( - SELECT - upid, - drug_name_std, - COUNT(*) as cnt - FROM fact_interventions - GROUP BY upid, drug_name_std - ) - GROUP BY upid - ), - drug_costs AS ( - -- JSON object of drug costs per patient - SELECT - upid, - '{' || GROUP_CONCAT('"' || drug_name_std || '": ' || ROUND(total_cost, 2), ', ') || '}' as drug_costs_json - FROM ( - SELECT - upid, - drug_name_std, - SUM(price_actual) as total_cost - FROM fact_interventions - GROUP BY upid, drug_name_std - ) - GROUP BY upid - ), - drug_dates AS ( - -- JSON object of drug date ranges per patient - SELECT - upid, - '{' || GROUP_CONCAT('"' || drug_name_std || '": {"first": "' || first_date || '", "last": "' || last_date || '"}', ', ') || '}' as drug_date_ranges_json - FROM ( - SELECT - upid, - drug_name_std, - MIN(intervention_date) as first_date, - MAX(intervention_date) as last_date - FROM fact_interventions - GROUP BY upid, drug_name_std - ) - GROUP BY upid - ) - INSERT INTO mv_patient_treatment_summary ( - upid, org_name, directory, - first_seen_date, last_seen_date, days_treated, - total_cost, avg_cost_per_intervention, - intervention_count, unique_drug_count, - drug_sequence, drug_counts_json, drug_costs_json, drug_date_ranges_json, - source_row_count, computed_at - ) - SELECT - pa.upid, - pa.org_name, - pa.directory, - pa.first_seen_date, - pa.last_seen_date, - CAST(pa.days_treated AS INTEGER), - pa.total_cost, - pa.avg_cost_per_intervention, - pa.intervention_count, - pa.unique_drug_count, - ds.drug_sequence, - dc.drug_counts_json, - dco.drug_costs_json, - dd.drug_date_ranges_json, - pa.source_row_count, - CURRENT_TIMESTAMP - FROM patient_aggs pa - LEFT JOIN drug_sequences ds ON pa.upid = ds.upid - LEFT JOIN drug_counts dc ON pa.upid = dc.upid - LEFT JOIN drug_costs dco ON pa.upid = dco.upid - LEFT JOIN drug_dates dd ON pa.upid = dd.upid - """ - - logger.info("Executing MV refresh query...") - conn.execute(refresh_sql) - - # Get actual rows inserted - cursor = conn.execute("SELECT COUNT(*) FROM mv_patient_treatment_summary") - rows_inserted = cursor.fetchone()[0] - - refresh_time = time.time() - start_time - logger.info(f"MV refresh complete: {rows_inserted:,} rows in {refresh_time:.1f}s") - - # Report progress if callback provided - if progress_callback: - progress_callback(rows_inserted, total_patients) - - return MVRefreshResult( - patients_processed=total_patients, - rows_inserted=rows_inserted, - refresh_time_seconds=refresh_time, - success=True - ) - - except Exception as e: - refresh_time = time.time() - start_time - error_msg = str(e) - logger.error(f"MV refresh failed: {error_msg}") - return MVRefreshResult( - patients_processed=0, - rows_inserted=0, - refresh_time_seconds=refresh_time, - success=False, - error_message=error_msg - ) - - -def get_patient_summary_stats(db_manager: Optional[DatabaseManager] = None) -> dict: - """ - Get statistics about the patient treatment summary MV. - - Returns: - Dictionary with MV statistics. - """ - if db_manager is None: - db_manager = DatabaseManager() - - stats = {} - - with db_manager.get_connection() as conn: - # Total rows - cursor = conn.execute("SELECT COUNT(*) FROM mv_patient_treatment_summary") - stats["total_patients"] = cursor.fetchone()[0] - - if stats["total_patients"] == 0: - return stats - - # Aggregated statistics - cursor = conn.execute(""" - SELECT - SUM(total_cost) as total_cost_all, - AVG(total_cost) as avg_cost_per_patient, - SUM(intervention_count) as total_interventions, - AVG(intervention_count) as avg_interventions_per_patient, - AVG(unique_drug_count) as avg_drugs_per_patient, - AVG(days_treated) as avg_days_treated, - MIN(first_seen_date) as earliest_date, - MAX(last_seen_date) as latest_date, - MAX(computed_at) as last_refresh - FROM mv_patient_treatment_summary - """) - result = cursor.fetchone() - - stats["total_cost"] = result[0] if result[0] else 0 - stats["avg_cost_per_patient"] = result[1] if result[1] else 0 - stats["total_interventions"] = result[2] if result[2] else 0 - stats["avg_interventions_per_patient"] = result[3] if result[3] else 0 - stats["avg_drugs_per_patient"] = result[4] if result[4] else 0 - stats["avg_days_treated"] = result[5] if result[5] else 0 - stats["date_range"] = (result[6], result[7]) - stats["last_refresh"] = result[8] - - # Unique directories in MV - cursor = conn.execute("SELECT COUNT(DISTINCT directory) FROM mv_patient_treatment_summary") - stats["unique_directories"] = cursor.fetchone()[0] - - # Unique organizations in MV - cursor = conn.execute("SELECT COUNT(DISTINCT org_name) FROM mv_patient_treatment_summary") - stats["unique_organizations"] = cursor.fetchone()[0] - - return stats - - -def verify_mv_consistency(db_manager: Optional[DatabaseManager] = None) -> tuple[bool, str]: - """ - Verify that the MV is consistent with fact_interventions. - - Checks that: - - Patient counts match - - Total cost sums match - - Intervention counts match - - Returns: - Tuple of (is_consistent, message). - """ - if db_manager is None: - db_manager = DatabaseManager() - - with db_manager.get_connection() as conn: - # Get fact table counts - cursor = conn.execute(""" - SELECT - COUNT(DISTINCT upid) as patients, - SUM(price_actual) as total_cost, - COUNT(*) as interventions - FROM fact_interventions - """) - fact_row = cursor.fetchone() - fact_patients = fact_row[0] or 0 - fact_cost = fact_row[1] or 0 - fact_interventions = fact_row[2] or 0 - - # Get MV counts - cursor = conn.execute(""" - SELECT - COUNT(*) as patients, - SUM(total_cost) as total_cost, - SUM(intervention_count) as interventions - FROM mv_patient_treatment_summary - """) - mv_row = cursor.fetchone() - mv_patients = mv_row[0] or 0 - mv_cost = mv_row[1] or 0 - mv_interventions = mv_row[2] or 0 - - # Compare - issues = [] - - if fact_patients != mv_patients: - issues.append(f"Patient count mismatch: fact={fact_patients:,}, mv={mv_patients:,}") - - if mv_interventions != fact_interventions: - issues.append(f"Intervention count mismatch: fact={fact_interventions:,}, mv={mv_interventions:,}") - - # Allow small floating point differences in cost - cost_diff = abs(fact_cost - mv_cost) - if cost_diff > 0.01: - issues.append(f"Cost mismatch: fact={fact_cost:,.2f}, mv={mv_cost:,.2f}, diff={cost_diff:.2f}") - - if issues: - return False, "; ".join(issues) - - return True, f"MV consistent: {mv_patients:,} patients, {mv_interventions:,} interventions, £{mv_cost:,.2f} total" diff --git a/data_processing/schema.py b/data_processing/schema.py index 04db5cf..7d5ae7c 100644 --- a/data_processing/schema.py +++ b/data_processing/schema.py @@ -115,43 +115,6 @@ CREATE INDEX IF NOT EXISTS idx_ref_drug_indication_clusters_cluster ON ref_drug_ CREATE INDEX IF NOT EXISTS idx_ref_drug_indication_clusters_indication ON ref_drug_indication_clusters(indication); """ -REF_DRUG_SNOMED_MAPPING_SCHEMA = """ --- Direct SNOMED code mapping from drug to indication to GP diagnosis codes --- Source: data/drug_snomed_mapping_enriched.csv (163K rows) --- Used for direct GP record matching to assign diagnosis-based directorates --- and to support indication-based pathway hierarchy (Trust → Search_Term → Drug → Pathway) -CREATE TABLE IF NOT EXISTS ref_drug_snomed_mapping ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - drug_name TEXT NOT NULL, -- Original drug name from mapping - indication TEXT NOT NULL, -- Specific indication (603 unique values) - ta_id TEXT, -- NICE TA reference (e.g., TA568) - search_term TEXT NOT NULL, -- Simplified grouping (187 unique values) - snomed_code TEXT NOT NULL, -- SNOMED CT code for GP record matching - snomed_description TEXT, -- SNOMED code description - cleaned_drug_name TEXT NOT NULL, -- Standardized drug name for matching - primary_directorate TEXT, -- Primary directorate for this indication - all_directorates TEXT, -- Pipe-separated list of valid directorates - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - UNIQUE(drug_name, indication, snomed_code) -); - --- Index for looking up SNOMED codes by drug name (most common access pattern) -CREATE INDEX IF NOT EXISTS idx_ref_drug_snomed_mapping_drug ON ref_drug_snomed_mapping(drug_name); - --- Index for looking up by cleaned drug name (standardized matching) -CREATE INDEX IF NOT EXISTS idx_ref_drug_snomed_mapping_cleaned ON ref_drug_snomed_mapping(cleaned_drug_name); - --- Index for looking up by SNOMED code (reverse lookup from GP record) -CREATE INDEX IF NOT EXISTS idx_ref_drug_snomed_mapping_snomed ON ref_drug_snomed_mapping(snomed_code); - --- Index for grouping by search_term (indication-based hierarchy) -CREATE INDEX IF NOT EXISTS idx_ref_drug_snomed_mapping_search_term ON ref_drug_snomed_mapping(search_term); - --- Composite index for drug + snomed code (common lookup pattern) -CREATE INDEX IF NOT EXISTS idx_ref_drug_snomed_mapping_drug_snomed - ON ref_drug_snomed_mapping(cleaned_drug_name, snomed_code); -""" - # ============================================================================= # Pathway Data Architecture Schemas @@ -278,6 +241,7 @@ CREATE TABLE IF NOT EXISTS pathway_refresh_log ( snowflake_query_date_from TEXT, -- Start date of Snowflake query snowflake_query_date_to TEXT, -- End date of Snowflake query processing_duration_seconds REAL, -- How long the refresh took + source_row_count INTEGER, -- Number of Snowflake rows fetched created_at TEXT DEFAULT CURRENT_TIMESTAMP ); @@ -301,208 +265,6 @@ PATHWAY_TABLES_SCHEMA = f""" """ -# ============================================================================= -# Fact Table Schemas -# ============================================================================= - -FACT_INTERVENTIONS_SCHEMA = """ --- Patient intervention records (fact table) --- Source: HCD activity data (CSV/Parquet files or Snowflake) --- This is the main fact table storing all patient intervention events -CREATE TABLE IF NOT EXISTS fact_interventions ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - - -- Patient identification - upid TEXT NOT NULL, -- Unique Patient ID (Provider Code[:3] + PersonKey) - provider_code TEXT NOT NULL, -- Original provider code (3-5 chars) - person_key TEXT NOT NULL, -- Patient key from source system - - -- Intervention details - drug_name_raw TEXT, -- Original drug name from source - drug_name_std TEXT NOT NULL, -- Standardized drug name (via ref_drug_names) - intervention_date DATE NOT NULL, -- Date of intervention - price_actual REAL NOT NULL DEFAULT 0, -- Cost of intervention in GBP - - -- Organization and directory - org_name TEXT, -- Organization name (cleaned, no commas) - directory TEXT, -- Medical directory/specialty (may be "Undefined") - - -- Source tracking - source_file TEXT, -- Original file this record came from - loaded_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - - -- Additional clinical fields (optional, used in directory fallback logic) - treatment_function_code INTEGER, - additional_detail_1 TEXT, - additional_detail_2 TEXT, - additional_detail_3 TEXT, - additional_detail_4 TEXT, - additional_detail_5 TEXT -); - --- Primary indexes for common filter patterns used in generate_graph() --- UPID: Used for patient grouping, pathway analysis -CREATE INDEX IF NOT EXISTS idx_fact_interventions_upid ON fact_interventions(upid); - --- Drug name (standardized): Used for drug filtering -CREATE INDEX IF NOT EXISTS idx_fact_interventions_drug ON fact_interventions(drug_name_std); - --- Intervention date: Used for date range filtering (start_date, end_date, last_seen) -CREATE INDEX IF NOT EXISTS idx_fact_interventions_date ON fact_interventions(intervention_date); - --- Directory: Used for directory/specialty filtering -CREATE INDEX IF NOT EXISTS idx_fact_interventions_directory ON fact_interventions(directory); - --- Organization: Used for trust filtering (Provider Code maps to org_name) -CREATE INDEX IF NOT EXISTS idx_fact_interventions_org ON fact_interventions(org_name); - --- Composite index for common filter combination (trust + drug + directory) -CREATE INDEX IF NOT EXISTS idx_fact_interventions_composite - ON fact_interventions(org_name, drug_name_std, directory); - --- Composite index for date-based patient analysis -CREATE INDEX IF NOT EXISTS idx_fact_interventions_upid_date - ON fact_interventions(upid, intervention_date); -""" - - -# ============================================================================= -# Materialized View Schemas (Cached Aggregations) -# ============================================================================= - -MV_PATIENT_TREATMENT_SUMMARY_SCHEMA = """ --- Materialized view of patient treatment summaries --- Pre-computed aggregations per patient for faster pathway analysis --- Refreshed when fact_interventions data changes -CREATE TABLE IF NOT EXISTS mv_patient_treatment_summary ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - - -- Patient identification - upid TEXT NOT NULL UNIQUE, -- Unique Patient ID - - -- Organization and directory (for filtering) - org_name TEXT, -- Organization name (first org seen) - directory TEXT, -- Primary directory (first directory assigned) - - -- Date range - first_seen_date DATE NOT NULL, -- First intervention date - last_seen_date DATE NOT NULL, -- Last intervention date - days_treated INTEGER NOT NULL DEFAULT 0, -- Duration: last_seen - first_seen - - -- Cost aggregations - total_cost REAL NOT NULL DEFAULT 0, -- Sum of all intervention costs - avg_cost_per_intervention REAL, -- Average cost per intervention - - -- Treatment summary - intervention_count INTEGER NOT NULL DEFAULT 0, -- Total number of interventions - unique_drug_count INTEGER NOT NULL DEFAULT 0, -- Number of distinct drugs - - -- Drug sequence (pipe-separated standardized drug names in chronological order) - -- Example: "ADALIMUMAB|ETANERCEPT|INFLIXIMAB" - drug_sequence TEXT, - - -- Drug frequency counts (JSON: {"ADALIMUMAB": 5, "ETANERCEPT": 3}) - -- Stores count of each drug for this patient - drug_counts_json TEXT, - - -- Drug cost totals (JSON: {"ADALIMUMAB": 15000.00, "ETANERCEPT": 8000.00}) - -- Stores total cost per drug for this patient - drug_costs_json TEXT, - - -- Per-drug date ranges (JSON: {"ADALIMUMAB": {"first": "2023-01-01", "last": "2023-06-15"}, ...}) - -- Stores first/last date for each drug - drug_date_ranges_json TEXT, - - -- Metadata - computed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - source_row_count INTEGER -- Number of fact_interventions rows used -); - --- Index for fast patient lookup -CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_upid ON mv_patient_treatment_summary(upid); - --- Indexes for common filter patterns -CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_org ON mv_patient_treatment_summary(org_name); -CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_directory ON mv_patient_treatment_summary(directory); -CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_first_seen ON mv_patient_treatment_summary(first_seen_date); -CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_last_seen ON mv_patient_treatment_summary(last_seen_date); - --- Composite index for date range filtering (common in generate_graph) -CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_date_range - ON mv_patient_treatment_summary(first_seen_date, last_seen_date); - --- Composite index for org + directory + dates (full filter pattern) -CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_filter_composite - ON mv_patient_treatment_summary(org_name, directory, first_seen_date, last_seen_date); - --- Index for drug sequence pattern matching -CREATE INDEX IF NOT EXISTS idx_mv_patient_summary_drug_seq ON mv_patient_treatment_summary(drug_sequence); -""" - -MATERIALIZED_VIEWS_SCHEMA = f""" --- Materialized Views Schema --- Pre-computed aggregations for performance - -{MV_PATIENT_TREATMENT_SUMMARY_SCHEMA} -""" - - -# ============================================================================= -# File Tracking Schemas (Incremental Updates) -# ============================================================================= - -PROCESSED_FILES_SCHEMA = """ --- Tracks processed data files for incremental updates --- Enables detecting changed files by comparing hashes --- Stores processing status and statistics -CREATE TABLE IF NOT EXISTS processed_files ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - - -- File identification - file_path TEXT NOT NULL, -- Full path to the file - file_name TEXT NOT NULL, -- Just the filename (for display) - file_hash TEXT NOT NULL, -- SHA256 hash of file contents - - -- File metadata - file_size_bytes INTEGER, -- Size of file in bytes - file_modified_at TIMESTAMP, -- File's last modification timestamp - - -- Processing results - row_count INTEGER DEFAULT 0, -- Number of rows processed from this file - status TEXT NOT NULL DEFAULT 'pending', -- pending, processing, success, error - error_message TEXT, -- Error details if status='error' - - -- Timestamps - first_processed_at TIMESTAMP, -- When first processed - last_processed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - processing_duration_seconds REAL, -- How long processing took - - -- Uniqueness: only one record per file path - -- Hash changes indicate file content changed (needs reprocessing) - UNIQUE(file_path) -); - --- Index for fast lookup by file path -CREATE INDEX IF NOT EXISTS idx_processed_files_path ON processed_files(file_path); - --- Index for finding files by status (e.g., find all pending or errored files) -CREATE INDEX IF NOT EXISTS idx_processed_files_status ON processed_files(status); - --- Index for finding files by hash (detect if same file appears at different paths) -CREATE INDEX IF NOT EXISTS idx_processed_files_hash ON processed_files(file_hash); - --- Index for finding recently processed files -CREATE INDEX IF NOT EXISTS idx_processed_files_last_processed ON processed_files(last_processed_at); -""" - -FILE_TRACKING_SCHEMA = f""" --- File Tracking Schema --- Supports incremental data loading - -{PROCESSED_FILES_SCHEMA} -""" - - # ============================================================================= # Combined Schemas # ============================================================================= @@ -520,29 +282,14 @@ REFERENCE_TABLES_SCHEMA = f""" {REF_DRUG_DIRECTORY_MAP_SCHEMA} {REF_DRUG_INDICATION_CLUSTERS_SCHEMA} - -{REF_DRUG_SNOMED_MAPPING_SCHEMA} -""" - -FACT_TABLES_SCHEMA = f""" --- Fact Tables Schema --- Contains patient intervention data - -{FACT_INTERVENTIONS_SCHEMA} """ ALL_TABLES_SCHEMA = f""" -- Complete Database Schema --- Reference tables + Fact tables + Materialized views + File tracking + Pathway tables +-- Reference tables + Pathway tables {REFERENCE_TABLES_SCHEMA} -{FACT_TABLES_SCHEMA} - -{MATERIALIZED_VIEWS_SCHEMA} - -{FILE_TRACKING_SCHEMA} - {PATHWAY_TABLES_SCHEMA} """ @@ -580,26 +327,10 @@ def drop_reference_tables(conn: sqlite3.Connection) -> None: DROP TABLE IF EXISTS ref_directories; DROP TABLE IF EXISTS ref_drug_directory_map; DROP TABLE IF EXISTS ref_drug_indication_clusters; - DROP TABLE IF EXISTS ref_drug_snomed_mapping; """) logger.info("Reference tables dropped") -def create_drug_snomed_mapping_table(conn: sqlite3.Connection) -> None: - """ - Create the ref_drug_snomed_mapping table for direct SNOMED code mapping. - - This table stores mappings from drugs to SNOMED codes for GP record matching, - enabling diagnosis-based directorate assignment and indication-based pathways. - - Args: - conn: SQLite database connection. - """ - logger.info("Creating ref_drug_snomed_mapping table...") - conn.executescript(REF_DRUG_SNOMED_MAPPING_SCHEMA) - logger.info("ref_drug_snomed_mapping table created successfully") - - def get_reference_table_counts(conn: sqlite3.Connection) -> dict[str, int]: """ Get row counts for all reference tables. @@ -616,7 +347,6 @@ def get_reference_table_counts(conn: sqlite3.Connection) -> dict[str, int]: "ref_directories", "ref_drug_directory_map", "ref_drug_indication_clusters", - "ref_drug_snomed_mapping", ] counts = {} @@ -647,7 +377,6 @@ def verify_reference_tables_exist(conn: sqlite3.Connection) -> list[str]: "ref_directories", "ref_drug_directory_map", "ref_drug_indication_clusters", - "ref_drug_snomed_mapping", ] missing = [] @@ -662,164 +391,6 @@ def verify_reference_tables_exist(conn: sqlite3.Connection) -> list[str]: return missing -# ============================================================================= -# Fact Table Helper Functions -# ============================================================================= - -def create_fact_tables(conn: sqlite3.Connection) -> None: - """ - Create all fact tables in the database (including materialized views). - - Args: - conn: SQLite database connection. - """ - logger.info("Creating fact tables...") - conn.executescript(FACT_TABLES_SCHEMA) - conn.executescript(MATERIALIZED_VIEWS_SCHEMA) - logger.info("Fact tables created successfully") - - -def drop_fact_tables(conn: sqlite3.Connection) -> None: - """ - Drop all fact tables from the database. - - Args: - conn: SQLite database connection. - - Warning: - This will delete all patient intervention data. Use with caution. - """ - logger.warning("Dropping fact tables...") - conn.executescript(""" - DROP TABLE IF EXISTS fact_interventions; - DROP TABLE IF EXISTS mv_patient_treatment_summary; - """) - logger.info("Fact tables dropped") - - -def get_fact_table_counts(conn: sqlite3.Connection) -> dict[str, int]: - """ - Get row counts for all fact tables (including materialized views). - - Args: - conn: SQLite database connection. - - Returns: - Dictionary mapping table name to row count. - """ - tables = ["fact_interventions", "mv_patient_treatment_summary"] - counts = {} - - for table in tables: - cursor = conn.execute(f"SELECT COUNT(*) FROM {table}") - result = cursor.fetchone() - counts[table] = result[0] if result else 0 - - return counts - - -def verify_fact_tables_exist(conn: sqlite3.Connection) -> list[str]: - """ - Verify that all fact tables exist (including materialized views). - - Args: - conn: SQLite database connection. - - Returns: - List of missing table names. Empty list means all tables exist. - """ - required_tables = ["fact_interventions", "mv_patient_treatment_summary"] - missing = [] - - for table in required_tables: - cursor = conn.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name=?", - (table,) - ) - if cursor.fetchone() is None: - missing.append(table) - - return missing - - -# ============================================================================= -# File Tracking Helper Functions -# ============================================================================= - -def create_file_tracking_tables(conn: sqlite3.Connection) -> None: - """ - Create file tracking tables in the database. - - Args: - conn: SQLite database connection. - """ - logger.info("Creating file tracking tables...") - conn.executescript(FILE_TRACKING_SCHEMA) - logger.info("File tracking tables created successfully") - - -def drop_file_tracking_tables(conn: sqlite3.Connection) -> None: - """ - Drop file tracking tables from the database. - - Args: - conn: SQLite database connection. - - Warning: - This will delete all file tracking history. - """ - logger.warning("Dropping file tracking tables...") - conn.executescript(""" - DROP TABLE IF EXISTS processed_files; - """) - logger.info("File tracking tables dropped") - - -def get_file_tracking_counts(conn: sqlite3.Connection) -> dict[str, int]: - """ - Get row counts for file tracking tables. - - Args: - conn: SQLite database connection. - - Returns: - Dictionary mapping table name to row count. - """ - tables = ["processed_files"] - counts = {} - - for table in tables: - cursor = conn.execute(f"SELECT COUNT(*) FROM {table}") - result = cursor.fetchone() - counts[table] = result[0] if result else 0 - - return counts - - -def verify_file_tracking_tables_exist(conn: sqlite3.Connection) -> list[str]: - """ - Verify that file tracking tables exist. - - Args: - conn: SQLite database connection. - - Returns: - List of missing table names. Empty list means all tables exist. - """ - required_tables = ["processed_files"] - missing = [] - - for table in required_tables: - cursor = conn.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name=?", - (table,) - ) - if cursor.fetchone() is None: - missing.append(table) - - return missing - - # ============================================================================= # Pathway Table Helper Functions # ============================================================================= @@ -1050,13 +621,37 @@ def migrate_pathway_nodes_chart_type(conn: sqlite3.Connection) -> tuple[bool, st return False, f"Migration failed: {e}" +def migrate_refresh_log_source_row_count(conn: sqlite3.Connection) -> tuple[bool, str]: + """Add source_row_count column to pathway_refresh_log if it doesn't exist. + + This column stores the Snowflake row count for display in the UI footer. + """ + cursor = conn.execute("PRAGMA table_info(pathway_refresh_log)") + columns = [row[1] for row in cursor.fetchall()] + + if "source_row_count" in columns: + return True, "source_row_count column already exists" + + logger.info("Adding source_row_count column to pathway_refresh_log...") + try: + conn.execute(""" + ALTER TABLE pathway_refresh_log + ADD COLUMN source_row_count INTEGER + """) + conn.commit() + return True, "Added source_row_count column" + except Exception as e: + logger.error(f"Failed to add source_row_count column: {e}") + return False, f"Migration failed: {e}" + + # ============================================================================= # Combined Helper Functions # ============================================================================= def create_all_tables(conn: sqlite3.Connection) -> None: """ - Create all tables (reference + fact) in the database. + Create all tables (reference + pathway) in the database. Args: conn: SQLite database connection. @@ -1078,8 +673,6 @@ def drop_all_tables(conn: sqlite3.Connection) -> None: """ logger.warning("Dropping all tables...") drop_pathway_tables(conn) - drop_file_tracking_tables(conn) - drop_fact_tables(conn) drop_reference_tables(conn) logger.info("All tables dropped") @@ -1096,8 +689,6 @@ def get_all_table_counts(conn: sqlite3.Connection) -> dict[str, int]: """ counts = {} counts.update(get_reference_table_counts(conn)) - counts.update(get_fact_table_counts(conn)) - counts.update(get_file_tracking_counts(conn)) counts.update(get_pathway_table_counts(conn)) return counts @@ -1114,7 +705,5 @@ def verify_all_tables_exist(conn: sqlite3.Connection) -> list[str]: """ missing = [] missing.extend(verify_reference_tables_exist(conn)) - missing.extend(verify_fact_tables_exist(conn)) - missing.extend(verify_file_tracking_tables_exist(conn)) missing.extend(verify_pathway_tables_exist(conn)) return missing diff --git a/pathways_app/pathways_app.py b/pathways_app/pathways_app.py index 29682cc..acab390 100644 --- a/pathways_app/pathways_app.py +++ b/pathways_app/pathways_app.py @@ -5,7 +5,7 @@ Single-page dashboard with reactive filtering and real-time chart updates. Design reference: DESIGN_SYSTEM.md """ -from datetime import datetime, timedelta +from datetime import datetime from pathlib import Path from typing import Any @@ -69,14 +69,6 @@ class AppState(rx.State): # Data freshness tracking last_updated: str = "" # ISO format timestamp of last data load - # Raw data storage - list of dicts (Reflex-friendly) - # Each dict represents a patient record with keys like: - # UPID, Drug Name, Intervention Date, Price Actual, Directory, etc. - raw_data: list[dict[str, Any]] = [] - - # Latest date in dataset (detected on load, used for "to" date defaults) - latest_date_in_data: str = "" - # ========================================================================= # UI State Variables # ========================================================================= @@ -111,22 +103,7 @@ class AppState(rx.State): {"value": "12mo", "label": "Last 12 months"}, ] - # Legacy date filter state (kept for backwards compatibility, will be removed) - # Filter toggle state - initiated_filter_enabled: bool = False - last_seen_filter_enabled: bool = True - - # Date filter values (ISO format strings YYYY-MM-DD) - # Initiated filter: Defaults empty (filter is OFF by default) - initiated_from_date: str = "" - initiated_to_date: str = "" - - # Last Seen filter: Defaults to last 6 months (filter is ON by default) - # These will be updated on data load to use actual latest date - last_seen_from_date: str = (datetime.now() - timedelta(days=180)).strftime("%Y-%m-%d") - last_seen_to_date: str = datetime.now().strftime("%Y-%m-%d") - - # Available options for dropdowns (populated from data in Phase 3) + # Available options for dropdowns (populated from data) available_drugs: list[str] = ["Drug A", "Drug B", "Drug C", "Drug D", "Drug E"] available_indications: list[str] = ["Indication 1", "Indication 2", "Indication 3"] available_directorates: list[str] = ["Medical", "Surgical", "Oncology", "Rheumatology"] @@ -146,45 +123,7 @@ class AppState(rx.State): indication_dropdown_open: bool = False directorate_dropdown_open: bool = False - # Event handlers for filter toggles - def toggle_initiated_filter(self): - """Toggle initiated date filter on/off.""" - self.initiated_filter_enabled = not self.initiated_filter_enabled - if self.data_loaded: - self.apply_filters() - - def toggle_last_seen_filter(self): - """Toggle last seen date filter on/off.""" - self.last_seen_filter_enabled = not self.last_seen_filter_enabled - if self.data_loaded: - self.apply_filters() - - # Event handlers for date changes - def set_initiated_from(self, value: str): - """Set initiated from date.""" - self.initiated_from_date = value - if self.data_loaded: - self.apply_filters() - - def set_initiated_to(self, value: str): - """Set initiated to date.""" - self.initiated_to_date = value - if self.data_loaded: - self.apply_filters() - - def set_last_seen_from(self, value: str): - """Set last seen from date.""" - self.last_seen_from_date = value - if self.data_loaded: - self.apply_filters() - - def set_last_seen_to(self, value: str): - """Set last seen to date.""" - self.last_seen_to_date = value - if self.data_loaded: - self.apply_filters() - - # Event handlers for date filter dropdowns (new pathway_nodes approach) + # Event handlers for date filter dropdowns def set_initiated_filter(self, value: str): """Set initiated filter dropdown value.""" self.selected_initiated = value @@ -461,141 +400,6 @@ class AppState(rx.State): # Filter Logic Methods # ========================================================================= - def apply_filters(self): - """ - Apply current filter state to data and update KPI values. - - This method queries the SQLite database with the current filter settings: - - Initiated date filter: filters patients whose FIRST intervention date is within range - - Last Seen date filter: filters patients whose LAST intervention date is within range - - Drug filter: filters by selected drugs (empty = all) - - Directorate filter: filters by selected directorates (empty = all) - - Note: Indication filter is not implemented at the database level since indications - are derived from drug mappings, not stored directly in fact_interventions. - - Updates: unique_patients, total_drugs, total_cost, and filtered_record_count - """ - import sqlite3 - - db_path = Path("data/pathways.db") - - if not db_path.exists(): - self.error_message = "Unable to connect to database. Please ensure data has been loaded." - return - - try: - conn = sqlite3.connect(str(db_path)) - cursor = conn.cursor() - - # Build the filter query dynamically - # We use a CTE to compute first_seen and last_seen dates per patient, - # then filter based on those dates if date filters are enabled - - where_clauses = [] - params = [] - - # Drug filter (if any drugs selected) - if self.selected_drugs: - placeholders = ",".join("?" * len(self.selected_drugs)) - where_clauses.append(f"drug_name_std IN ({placeholders})") - params.extend(self.selected_drugs) - - # Directorate filter (if any directorates selected) - if self.selected_directorates: - placeholders = ",".join("?" * len(self.selected_directorates)) - where_clauses.append(f"directory IN ({placeholders})") - params.extend(self.selected_directorates) - - # Build WHERE clause for base data filtering - base_where = "" - if where_clauses: - base_where = "WHERE " + " AND ".join(where_clauses) - - # Date filter logic: - # - "Initiated" filters patients whose FIRST intervention is within the date range - # - "Last Seen" filters patients whose LAST intervention is within the date range - # We need to use a subquery to compute patient-level date ranges - - having_clauses = [] - having_params = [] - - # Initiated filter (when enabled) - if self.initiated_filter_enabled and self.initiated_from_date: - having_clauses.append("first_seen_date >= ?") - having_params.append(self.initiated_from_date) - if self.initiated_filter_enabled and self.initiated_to_date: - having_clauses.append("first_seen_date <= ?") - having_params.append(self.initiated_to_date) - - # Last Seen filter (when enabled) - if self.last_seen_filter_enabled and self.last_seen_from_date: - having_clauses.append("last_seen_date >= ?") - having_params.append(self.last_seen_from_date) - if self.last_seen_filter_enabled and self.last_seen_to_date: - having_clauses.append("last_seen_date <= ?") - having_params.append(self.last_seen_to_date) - - having_clause = "" - if having_clauses: - having_clause = "HAVING " + " AND ".join(having_clauses) - - # Query to get filtered patient UPIDs - # This computes per-patient first/last seen dates and filters accordingly - patient_filter_query = f""" - WITH patient_dates AS ( - SELECT - upid, - MIN(intervention_date) as first_seen_date, - MAX(intervention_date) as last_seen_date - FROM fact_interventions - {base_where} - GROUP BY upid - {having_clause} - ) - SELECT upid FROM patient_dates - """ - - # Now get KPI values for filtered patients - kpi_query = f""" - WITH filtered_patients AS ( - {patient_filter_query} - ) - SELECT - COUNT(DISTINCT f.upid) as unique_patients, - COUNT(DISTINCT f.drug_name_std) as unique_drugs, - COALESCE(SUM(f.price_actual), 0) as total_cost, - COUNT(*) as record_count - FROM fact_interventions f - INNER JOIN filtered_patients fp ON f.upid = fp.upid - {base_where.replace('WHERE', 'AND') if base_where else ''} - """ - - # Combine all params: base params for CTE, having params, then base params again for final join - all_params = params + having_params - if where_clauses: - all_params.extend(params) # For the AND conditions in the final query - - cursor.execute(kpi_query, all_params) - result = cursor.fetchone() - - if result: - self.unique_patients = result[0] or 0 - self.total_drugs = result[1] or 0 - self.total_cost = float(result[2]) if result[2] else 0.0 - # Note: filtered_record_count could be stored if needed - - conn.close() - self.error_message = "" - - # Update chart data with new filtered results - self.prepare_chart_data() - - except sqlite3.Error as e: - self.error_message = f"Unable to filter data. Database error: {str(e)}" - except Exception as e: - self.error_message = f"An unexpected error occurred while filtering. Details: {str(e)}" - # ========================================================================= # Data Loading Methods # ========================================================================= @@ -604,11 +408,8 @@ class AppState(rx.State): """ Load data from SQLite database on app initialization. - This method: - 1. Connects to the SQLite database (data/pathways.db) - 2. Loads available drugs, indications, directorates from actual data - 3. Detects the latest date in the dataset for "to" date defaults - 4. Updates total_records, last_updated, and data_loaded state + Sources available drugs/directorates from pathway_nodes and total_records + from the latest pathway_refresh_log entry. """ import sqlite3 @@ -622,30 +423,38 @@ class AppState(rx.State): conn = sqlite3.connect(str(db_path)) cursor = conn.cursor() - # Get total records - cursor.execute("SELECT COUNT(*) FROM fact_interventions") - self.total_records = cursor.fetchone()[0] - - if self.total_records == 0: - self.error_message = "The database is empty. No patient records found." - conn.close() - return - - # Get available drugs (distinct, sorted) + # Get total source records from latest completed refresh log cursor.execute(""" - SELECT DISTINCT drug_name_std - FROM fact_interventions - WHERE drug_name_std IS NOT NULL AND drug_name_std != '' - ORDER BY drug_name_std + SELECT source_row_count, completed_at + FROM pathway_refresh_log + WHERE status = 'completed' + ORDER BY started_at DESC + LIMIT 1 + """) + refresh_row = cursor.fetchone() + if refresh_row: + self.total_records = refresh_row[0] or 0 + if refresh_row[1]: + self.last_updated = refresh_row[1] + else: + self.total_records = 0 + + # Get available drugs from pathway_nodes (level 3 = drug nodes) + cursor.execute(""" + SELECT DISTINCT labels + FROM pathway_nodes + WHERE level = 3 AND labels IS NOT NULL AND labels != '' + ORDER BY labels """) self.available_drugs = [row[0] for row in cursor.fetchall()] - # Get available directories (distinct, sorted) + # Get available directorates from directory chart pathway_nodes (level 2) cursor.execute(""" - SELECT DISTINCT directory - FROM fact_interventions - WHERE directory IS NOT NULL AND directory != '' - ORDER BY directory + SELECT DISTINCT labels + FROM pathway_nodes + WHERE level = 2 AND chart_type = 'directory' + AND labels IS NOT NULL AND labels != '' + ORDER BY labels """) self.available_directorates = [row[0] for row in cursor.fetchall()] @@ -658,50 +467,17 @@ class AppState(rx.State): """) self.available_indications = [row[0] for row in cursor.fetchall()] - # If no indications in reference table, use placeholder if not self.available_indications: self.available_indications = ["(No indications available)"] - # Get date range from data - cursor.execute(""" - SELECT MIN(intervention_date), MAX(intervention_date) - FROM fact_interventions - """) - date_range = cursor.fetchone() - min_date, max_date = date_range - - # Update latest_date_in_data and set "to" date defaults - if max_date: - self.latest_date_in_data = max_date - self.last_seen_to_date = max_date - self.initiated_to_date = max_date - - # Set "from" date for last_seen filter (6 months before max_date) - max_dt = datetime.strptime(max_date, "%Y-%m-%d") - six_months_ago = max_dt - timedelta(days=180) - self.last_seen_from_date = six_months_ago.strftime("%Y-%m-%d") - - # Get unique patient count for KPIs - cursor.execute("SELECT COUNT(DISTINCT upid) FROM fact_interventions") - self.unique_patients = cursor.fetchone()[0] - - # Get unique drug count - self.total_drugs = len(self.available_drugs) - - # Get total cost - cursor.execute("SELECT SUM(price_actual) FROM fact_interventions") - total_cost_result = cursor.fetchone()[0] - self.total_cost = float(total_cost_result) if total_cost_result else 0.0 - conn.close() - # Set data_loaded and last_updated self.data_loaded = True - self.last_updated = datetime.now().isoformat() + if not self.last_updated: + self.last_updated = datetime.now().isoformat() self.error_message = "" # Load pre-computed pathway data for the default date filter - # This replaces apply_filters() which used dynamic calculation self.load_pathway_data() except sqlite3.Error as e: @@ -986,269 +762,6 @@ class AppState(rx.State): chart_data: list[dict[str, Any]] = [] chart_title: str = "" - def prepare_chart_data(self): - """ - Prepare hierarchical data for Plotly icicle chart. - - This method queries the filtered patient data and transforms it into - a hierarchical structure: Root → Trust → Directory → Drug - - The chart data is stored in self.chart_data as a list of dicts with: - - parents: Parent node identifier - - ids: Unique node identifier (hierarchical path) - - labels: Display label - - value: Patient count - - cost: Total cost - - colour: Color value (proportion of parent) - - Updates: chart_data, chart_title, chart_loading - """ - import sqlite3 - - db_path = Path("data/pathways.db") - - if not db_path.exists(): - self.error_message = "Unable to generate chart. Database not found." - self.chart_data = [] - return - - self.chart_loading = True - - try: - conn = sqlite3.connect(str(db_path)) - cursor = conn.cursor() - - # Build WHERE clause for filters - where_clauses = [] - params = [] - - # Drug filter (if any drugs selected) - if self.selected_drugs: - placeholders = ",".join("?" * len(self.selected_drugs)) - where_clauses.append(f"drug_name_std IN ({placeholders})") - params.extend(self.selected_drugs) - - # Directorate filter (if any directorates selected) - if self.selected_directorates: - placeholders = ",".join("?" * len(self.selected_directorates)) - where_clauses.append(f"directory IN ({placeholders})") - params.extend(self.selected_directorates) - - base_where = "" - if where_clauses: - base_where = "WHERE " + " AND ".join(where_clauses) - - # Build date filter HAVING clauses for patient-level filtering - having_clauses = [] - having_params = [] - - if self.initiated_filter_enabled and self.initiated_from_date: - having_clauses.append("first_seen >= ?") - having_params.append(self.initiated_from_date) - if self.initiated_filter_enabled and self.initiated_to_date: - having_clauses.append("first_seen <= ?") - having_params.append(self.initiated_to_date) - - if self.last_seen_filter_enabled and self.last_seen_from_date: - having_clauses.append("last_seen >= ?") - having_params.append(self.last_seen_from_date) - if self.last_seen_filter_enabled and self.last_seen_to_date: - having_clauses.append("last_seen <= ?") - having_params.append(self.last_seen_to_date) - - having_clause = "" - if having_clauses: - having_clause = "HAVING " + " AND ".join(having_clauses) - - # Query to get aggregated data by Trust -> Directory -> Drug - # fact_interventions already has org_name, use it directly - chart_query = f""" - WITH filtered_patients AS ( - SELECT upid - FROM ( - SELECT - upid, - MIN(intervention_date) as first_seen, - MAX(intervention_date) as last_seen - FROM fact_interventions - {base_where} - GROUP BY upid - {having_clause} - ) - ), - patient_records AS ( - SELECT - f.upid, - COALESCE(f.org_name, f.provider_code) as trust_name, - f.directory, - f.drug_name_std, - f.price_actual - FROM fact_interventions f - INNER JOIN filtered_patients fp ON f.upid = fp.upid - {base_where.replace('WHERE', 'AND') if base_where else ''} - ) - SELECT - trust_name, - directory, - drug_name_std, - COUNT(DISTINCT upid) as patient_count, - COALESCE(SUM(price_actual), 0) as total_cost - FROM patient_records - GROUP BY trust_name, directory, drug_name_std - ORDER BY trust_name, directory, drug_name_std - """ - - all_params = params + having_params - if where_clauses: - all_params.extend(params) - - cursor.execute(chart_query, all_params) - rows = cursor.fetchall() - - conn.close() - - # Build hierarchical chart data - chart_data = [] - hierarchy_totals = {} # Track totals for calculating color values - - # Root node - root_id = "N&WICS" - chart_data.append({ - "parents": "", - "ids": root_id, - "labels": "Norfolk & Waveney ICS", - "value": 0, - "cost": 0.0, - "colour": 1.0, - }) - - # Process rows to build hierarchy - trust_totals = {} - directory_totals = {} - drug_data = [] - - for row in rows: - trust_name, directory, drug_name, patient_count, cost = row - - if not trust_name or not directory or not drug_name: - continue - - # Trust level - trust_id = f"{root_id} - {trust_name}" - if trust_id not in trust_totals: - trust_totals[trust_id] = {"value": 0, "cost": 0.0, "label": trust_name} - trust_totals[trust_id]["value"] += patient_count - trust_totals[trust_id]["cost"] += cost - - # Directory level - dir_id = f"{trust_id} - {directory}" - if dir_id not in directory_totals: - directory_totals[dir_id] = { - "value": 0, - "cost": 0.0, - "label": directory, - "parent": trust_id, - } - directory_totals[dir_id]["value"] += patient_count - directory_totals[dir_id]["cost"] += cost - - # Drug level (leaf) - drug_id = f"{dir_id} - {drug_name}" - drug_data.append({ - "ids": drug_id, - "labels": drug_name, - "parent": dir_id, - "value": patient_count, - "cost": float(cost), - }) - - # Calculate root total - root_total = sum(t["value"] for t in trust_totals.values()) - root_cost = sum(t["cost"] for t in trust_totals.values()) - chart_data[0]["value"] = root_total - chart_data[0]["cost"] = root_cost - - # Add trust nodes with color proportions - for trust_id, data in trust_totals.items(): - colour = data["value"] / root_total if root_total > 0 else 0 - chart_data.append({ - "parents": root_id, - "ids": trust_id, - "labels": data["label"], - "value": data["value"], - "cost": data["cost"], - "colour": colour, - }) - - # Add directory nodes with color proportions - for dir_id, data in directory_totals.items(): - parent_total = trust_totals[data["parent"]]["value"] - colour = data["value"] / parent_total if parent_total > 0 else 0 - chart_data.append({ - "parents": data["parent"], - "ids": dir_id, - "labels": data["label"], - "value": data["value"], - "cost": data["cost"], - "colour": colour, - }) - - # Add drug nodes with color proportions - for drug in drug_data: - parent_dir = drug["parent"] - parent_total = directory_totals[parent_dir]["value"] - colour = drug["value"] / parent_total if parent_total > 0 else 0 - chart_data.append({ - "parents": parent_dir, - "ids": drug["ids"], - "labels": drug["labels"], - "value": drug["value"], - "cost": drug["cost"], - "colour": colour, - }) - - self.chart_data = chart_data - self.chart_title = self._generate_chart_title() - self.chart_loading = False - self.error_message = "" - - except sqlite3.Error as e: - self.error_message = f"Unable to generate chart. Database error: {str(e)}" - self.chart_data = [] - self.chart_loading = False - except Exception as e: - self.error_message = f"Unable to generate chart. Details: {str(e)}" - self.chart_data = [] - self.chart_loading = False - - def _generate_chart_title(self) -> str: - """Generate chart title based on current filter state.""" - parts = [] - - # Date range info - if self.last_seen_filter_enabled: - parts.append(f"Last seen: {self.last_seen_from_date} to {self.last_seen_to_date}") - elif self.initiated_filter_enabled: - parts.append(f"Initiated: {self.initiated_from_date} to {self.initiated_to_date}") - - # Drug selection info - if self.selected_drugs: - if len(self.selected_drugs) <= 3: - parts.append(", ".join(self.selected_drugs)) - else: - parts.append(f"{len(self.selected_drugs)} drugs selected") - - # Directorate selection info - if self.selected_directorates: - if len(self.selected_directorates) <= 2: - parts.append(", ".join(self.selected_directorates)) - else: - parts.append(f"{len(self.selected_directorates)} directorates") - - if parts: - return " | ".join(parts) - return "All Patients" - # ========================================================================= # Plotly Chart Generation # ========================================================================= diff --git a/tests/test_large_dataset_performance.py b/tests/test_large_dataset_performance.py deleted file mode 100644 index 53d338c..0000000 --- a/tests/test_large_dataset_performance.py +++ /dev/null @@ -1,446 +0,0 @@ -""" -Large dataset performance tests for the Patient Pathway Analysis tool. - -This module tests the system's ability to handle realistic workloads: -1. Full dataset analysis (all drugs, trusts, directories) -2. Memory usage under load -3. Scalability characteristics - -Run with: python -m pytest tests/test_large_dataset_performance.py -v -""" - -import gc -import time -import tracemalloc -from datetime import date -from pathlib import Path - -import pytest - -# Mark all tests in this module as large dataset tests -pytestmark = pytest.mark.largedata - - -class TestLargeDatasetPerformance: - """Performance tests with full dataset.""" - - @pytest.fixture(autouse=True) - def setup_paths(self): - """Set up paths and verify data exists.""" - from core import default_paths - from data_processing import get_loader - - # Check if database exists - db_path = default_paths.data_dir / "pathways.db" - if not db_path.exists(): - pytest.skip("SQLite database not found") - - self.paths = default_paths - self.loader = get_loader('sqlite') - - # Load data once - result = self.loader.load() - if result is None or result.df is None or len(result.df) == 0: - pytest.skip("No data available in database") - - self.df = result.df - self.row_count = result.row_count - - def test_data_load_time_acceptable(self): - """Data loading should complete in under 5 seconds.""" - from data_processing import get_loader - - gc.collect() - start = time.perf_counter() - loader = get_loader('sqlite') - result = loader.load() - elapsed = time.perf_counter() - start - - assert result is not None, "Data loading failed" - assert result.row_count > 0, "No data loaded" - # Allow 5 seconds for data loading - assert elapsed < 5.0, f"Data loading took {elapsed:.2f}s (target: <5s)" - - def test_analysis_pipeline_completes(self): - """Full analysis pipeline should complete without error.""" - from analysis.pathway_analyzer import generate_icicle_chart - import pandas as pd - - # Get available filters from actual data - trusts = self.df['Provider Code'].unique().tolist()[:20] - drugs = self.df['Drug Name'].dropna().unique().tolist()[:10] - directories = self.df['Directory'].dropna().unique().tolist() - - # Load org codes for trust name mapping - org_codes = pd.read_csv(self.paths.org_codes_csv, index_col=1) - trust_names = [] - for t in trusts: - if t in org_codes.index: - trust_names.append(org_codes.loc[t, 'Name']) - if not trust_names: - trust_names = org_codes['Name'].tolist()[:20] - - # Run analysis with reasonable filter - ice_df, title = generate_icicle_chart( - df=self.df, - start_date="2020-01-01", - end_date="2025-01-01", - last_seen_date="2020-01-01", - trust_filter=trust_names, - drug_filter=drugs, - directory_filter=directories, - minimum_num_patients=1, - title="Large Dataset Test", - paths=self.paths, - ) - - # Should produce some results - assert ice_df is not None, "Analysis produced no results" - assert len(ice_df) > 0, "Analysis produced empty results" - - def test_analysis_pipeline_time_acceptable(self): - """Analysis pipeline should complete in under 60 seconds.""" - from analysis.pathway_analyzer import generate_icicle_chart - import pandas as pd - - # Get available filters from actual data - trusts = self.df['Provider Code'].unique().tolist()[:20] - drugs = self.df['Drug Name'].dropna().unique().tolist()[:10] - directories = self.df['Directory'].dropna().unique().tolist() - - # Load org codes for trust name mapping - org_codes = pd.read_csv(self.paths.org_codes_csv, index_col=1) - trust_names = [] - for t in trusts: - if t in org_codes.index: - trust_names.append(org_codes.loc[t, 'Name']) - if not trust_names: - trust_names = org_codes['Name'].tolist()[:20] - - gc.collect() - start = time.perf_counter() - - ice_df, title = generate_icicle_chart( - df=self.df, - start_date="2020-01-01", - end_date="2025-01-01", - last_seen_date="2020-01-01", - trust_filter=trust_names, - drug_filter=drugs, - directory_filter=directories, - minimum_num_patients=1, - title="Performance Test", - paths=self.paths, - ) - - elapsed = time.perf_counter() - start - - # Allow 60 seconds for full analysis (observed ~19s with 440K rows) - assert elapsed < 60.0, f"Analysis took {elapsed:.2f}s (target: <60s)" - print(f"\n Analysis completed in {elapsed:.2f}s with {len(ice_df) if ice_df is not None else 0} result rows") - - def test_memory_usage_acceptable(self): - """Memory usage should not exceed 500MB during analysis.""" - from analysis.pathway_analyzer import generate_icicle_chart - import pandas as pd - - # Get available filters from actual data - trusts = self.df['Provider Code'].unique().tolist()[:15] - drugs = self.df['Drug Name'].dropna().unique().tolist()[:5] - directories = self.df['Directory'].dropna().unique().tolist() - - # Load org codes for trust name mapping - org_codes = pd.read_csv(self.paths.org_codes_csv, index_col=1) - trust_names = [] - for t in trusts: - if t in org_codes.index: - trust_names.append(org_codes.loc[t, 'Name']) - if not trust_names: - trust_names = org_codes['Name'].tolist()[:15] - - gc.collect() - tracemalloc.start() - - ice_df, title = generate_icicle_chart( - df=self.df, - start_date="2020-01-01", - end_date="2025-01-01", - last_seen_date="2020-01-01", - trust_filter=trust_names, - drug_filter=drugs, - directory_filter=directories, - minimum_num_patients=1, - title="Memory Test", - paths=self.paths, - ) - - current, peak = tracemalloc.get_traced_memory() - tracemalloc.stop() - - peak_mb = peak / 1024 / 1024 - - # Allow 500MB peak memory - assert peak_mb < 500, f"Peak memory {peak_mb:.1f}MB exceeds 500MB limit" - print(f"\n Peak memory usage: {peak_mb:.1f}MB") - - def test_figure_creation_scales(self): - """Figure creation time should scale linearly with result size.""" - from visualization.plotly_generator import create_icicle_figure - import pandas as pd - import numpy as np - - # Test with different sizes - sizes = [100, 500, 1000, 2000] - times = [] - - for n_rows in sizes: - sample_data = { - 'parents': ['N&WICS'] * n_rows, - 'ids': [f'N&WICS - Test{i}' for i in range(n_rows)], - 'labels': [f'Test{i}' for i in range(n_rows)], - 'value': np.random.randint(1, 100, n_rows), - 'colour': np.random.random(n_rows), - 'cost': np.random.randint(1000, 100000, n_rows), - 'costpp': np.random.randint(100, 10000, n_rows), - 'cost_pp_pa': [str(np.random.randint(100, 10000)) for _ in range(n_rows)], - 'First seen': pd.to_datetime(['2024-01-01'] * n_rows), - 'Last seen': pd.to_datetime(['2024-12-31'] * n_rows), - 'First seen (Parent)': ['2024-01-01'] * n_rows, - 'Last seen (Parent)': ['2024-12-31'] * n_rows, - 'average_spacing': ['Test spacing'] * n_rows, - 'avg_days': pd.to_timedelta([100] * n_rows, unit='D'), - } - sample_df = pd.DataFrame(sample_data) - - gc.collect() - start = time.perf_counter() - fig = create_icicle_figure(sample_df, f"Scale Test {n_rows}") - elapsed = time.perf_counter() - start - - times.append(elapsed) - - # Check that time scaling is roughly linear (not exponential) - # If time doubles when size doubles, it's linear - # We allow some variance, so check that 10x data doesn't take more than 20x time - time_ratio = times[-1] / times[0] - size_ratio = sizes[-1] / sizes[0] - - # Allow 3x the expected linear scaling - max_allowed_ratio = size_ratio * 3 - - assert time_ratio < max_allowed_ratio, ( - f"Figure creation doesn't scale well: " - f"{sizes[-1]} rows took {times[-1]:.3f}s vs {sizes[0]} rows at {times[0]:.3f}s " - f"(ratio {time_ratio:.1f}x, expected <{max_allowed_ratio:.1f}x)" - ) - - print(f"\n Figure scaling: {sizes[0]} rows: {times[0]*1000:.1f}ms, " - f"{sizes[-1]} rows: {times[-1]*1000:.1f}ms (ratio: {time_ratio:.1f}x)") - - -class TestDataVolumeStress: - """Stress tests to verify system handles various data volumes.""" - - @pytest.fixture(autouse=True) - def setup_paths(self): - """Set up paths and verify data exists.""" - from core import default_paths - from data_processing import get_loader - - # Check if database exists - db_path = default_paths.data_dir / "pathways.db" - if not db_path.exists(): - pytest.skip("SQLite database not found") - - self.paths = default_paths - self.loader = get_loader('sqlite') - - # Load data once - result = self.loader.load() - if result is None or result.df is None or len(result.df) == 0: - pytest.skip("No data available in database") - - self.df = result.df - - def test_handles_all_drugs(self): - """Analysis can handle filtering by all drugs.""" - from analysis.pathway_analyzer import prepare_data - import pandas as pd - - all_drugs = self.df['Drug Name'].dropna().unique().tolist() - - # Load org codes - org_codes = pd.read_csv(self.paths.org_codes_csv, index_col=1) - trust_names = org_codes['Name'].tolist()[:5] - - result = prepare_data( - df=self.df, - trust_filter=trust_names, - drug_filter=all_drugs, - directory_filter=self.df['Directory'].dropna().unique().tolist(), - paths=self.paths, - ) - - # Should complete without error (returns tuple) - assert result is not None - assert len(result) == 3 # (df, org_codes, directory_df) - - def test_handles_all_trusts(self): - """Analysis can handle filtering by all trusts.""" - from analysis.pathway_analyzer import prepare_data - import pandas as pd - - # Load org codes - org_codes = pd.read_csv(self.paths.org_codes_csv, index_col=1) - all_trust_names = org_codes['Name'].tolist() - - result = prepare_data( - df=self.df, - trust_filter=all_trust_names, - drug_filter=['ADALIMUMAB', 'ETANERCEPT'], - directory_filter=self.df['Directory'].dropna().unique().tolist(), - paths=self.paths, - ) - - # Should complete without error (returns tuple) - assert result is not None - assert len(result) == 3 # (df, org_codes, directory_df) - - def test_handles_wide_date_range(self): - """Analysis can handle a wide date range via generate_icicle_chart.""" - from analysis.pathway_analyzer import generate_icicle_chart - import pandas as pd - - # Load org codes - org_codes = pd.read_csv(self.paths.org_codes_csv, index_col=1) - trust_names = org_codes['Name'].tolist()[:10] - - # Use very wide date range via full pipeline - ice_df, title = generate_icicle_chart( - df=self.df, - start_date="2010-01-01", - end_date="2030-01-01", - last_seen_date="2010-01-01", - trust_filter=trust_names, - drug_filter=self.df['Drug Name'].dropna().unique().tolist()[:5], - directory_filter=self.df['Directory'].dropna().unique().tolist(), - minimum_num_patients=1, - title="Wide Date Range Test", - paths=self.paths, - ) - - # Should complete without error - assert ice_df is not None or ice_df is None # Just verifying no exception - - def test_handles_minimum_patient_threshold(self): - """Analysis correctly applies minimum patient threshold.""" - from analysis.pathway_analyzer import generate_icicle_chart - import pandas as pd - - # Load org codes - org_codes = pd.read_csv(self.paths.org_codes_csv, index_col=1) - trust_names = org_codes['Name'].tolist()[:10] - - # Run with minimum 50 patients - ice_df_50, _ = generate_icicle_chart( - df=self.df, - start_date="2020-01-01", - end_date="2025-01-01", - last_seen_date="2020-01-01", - trust_filter=trust_names, - drug_filter=self.df['Drug Name'].dropna().unique().tolist()[:5], - directory_filter=self.df['Directory'].dropna().unique().tolist(), - minimum_num_patients=50, - title="Threshold Test 50", - paths=self.paths, - ) - - # Run with minimum 1 patient - ice_df_1, _ = generate_icicle_chart( - df=self.df, - start_date="2020-01-01", - end_date="2025-01-01", - last_seen_date="2020-01-01", - trust_filter=trust_names, - drug_filter=self.df['Drug Name'].dropna().unique().tolist()[:5], - directory_filter=self.df['Directory'].dropna().unique().tolist(), - minimum_num_patients=1, - title="Threshold Test 1", - paths=self.paths, - ) - - # Higher threshold should produce fewer or equal results - len_50 = len(ice_df_50) if ice_df_50 is not None else 0 - len_1 = len(ice_df_1) if ice_df_1 is not None else 0 - - assert len_50 <= len_1, ( - f"Higher minimum threshold should produce fewer results: " - f"min=50 gave {len_50} rows, min=1 gave {len_1} rows" - ) - - -class TestConcurrentOperations: - """Tests for handling multiple operations.""" - - @pytest.fixture(autouse=True) - def setup_paths(self): - """Set up paths and verify data exists.""" - from core import default_paths - from data_processing import get_loader - - # Check if database exists - db_path = default_paths.data_dir / "pathways.db" - if not db_path.exists(): - pytest.skip("SQLite database not found") - - self.paths = default_paths - - def test_multiple_data_loads(self): - """Multiple data loads should not cause issues.""" - from data_processing import get_loader - - results = [] - for i in range(3): - loader = get_loader('sqlite') - result = loader.load() - if result is not None: - results.append(result.row_count) - - # All loads should return same row count - assert len(set(results)) == 1, f"Inconsistent row counts: {results}" - - def test_sequential_analyses(self): - """Multiple sequential analyses should complete.""" - from analysis.pathway_analyzer import generate_icicle_chart - from data_processing import get_loader - import pandas as pd - - # Load data - loader = get_loader('sqlite') - result = loader.load() - if result is None or result.df is None: - pytest.skip("No data available") - - df = result.df - - # Load org codes - org_codes = pd.read_csv(self.paths.org_codes_csv, index_col=1) - trust_names = org_codes['Name'].tolist()[:5] - - # Run multiple analyses - for i in range(3): - ice_df, title = generate_icicle_chart( - df=df, - start_date="2020-01-01", - end_date="2025-01-01", - last_seen_date="2020-01-01", - trust_filter=trust_names, - drug_filter=['ADALIMUMAB'], - directory_filter=df['Directory'].dropna().unique().tolist(), - minimum_num_patients=1, - title=f"Sequential Test {i+1}", - paths=self.paths, - ) - - # Each should complete - assert ice_df is not None or ice_df is None # Just check no error