""" Diagnosis lookup module for NHS Patient Pathway Analysis. Provides functions to validate patient indications by checking GP diagnosis records against SNOMED cluster codes. Uses the drug-to-cluster mapping from drug_indication_clusters.csv and queries Snowflake for SNOMED codes and GP records. Key workflow: 1. Get drug's valid indication clusters from local mapping 2. Get all SNOMED codes for those clusters from Snowflake 3. Check if patient has any of those SNOMED codes in GP records 4. Report indication validation status IMPORTANT: HCD activity data indication codes are UNRELIABLE. This module uses GP/Primary Care data (PrimaryCareClinicalCoding) as the authoritative source. """ from dataclasses import dataclass, field from datetime import date, datetime from pathlib import Path from typing import Optional, Callable, Any, cast import csv from core.logging_config import get_logger from data_processing.database import DatabaseManager, default_db_manager from data_processing.snowflake_connector import ( SnowflakeConnector, get_connector, is_snowflake_available, is_snowflake_configured, SNOWFLAKE_AVAILABLE, ) from data_processing.cache import get_cache, is_cache_enabled logger = get_logger(__name__) @dataclass class ClusterSnomedCodes: """SNOMED codes for a clinical coding cluster.""" cluster_id: str cluster_description: str snomed_codes: list[str] = field(default_factory=list) snomed_descriptions: dict[str, str] = field(default_factory=dict) @property def code_count(self) -> int: return len(self.snomed_codes) @dataclass class IndicationValidationResult: """Result of validating a patient's indication for a drug.""" patient_pseudonym: str drug_name: str has_valid_indication: bool matched_cluster_id: Optional[str] = None matched_snomed_code: Optional[str] = None matched_snomed_description: Optional[str] = None checked_clusters: list[str] = field(default_factory=list) total_codes_checked: int = 0 source: str = "GP_SNOMED" # GP_SNOMED | NONE error_message: Optional[str] = None @dataclass class DrugIndicationMatchRate: """Match rate statistics for a drug's indication validation.""" drug_name: str total_patients: int patients_with_indication: int patients_without_indication: int match_rate: float # 0.0 to 1.0 clusters_checked: list[str] = field(default_factory=list) sample_unmatched: list[str] = field(default_factory=list) # Sample patient IDs def get_drug_clusters( drug_name: str, db_manager: Optional[DatabaseManager] = None ) -> list[dict]: """ Get all SNOMED cluster mappings for a drug from local SQLite. Args: drug_name: Drug name to look up (case-insensitive) db_manager: Optional DatabaseManager (defaults to default_db_manager) Returns: List of dicts with keys: drug_name, indication, cluster_id, cluster_description, nice_ta_reference """ if db_manager is None: db_manager = default_db_manager query = """ SELECT drug_name, indication, cluster_id, cluster_description, nice_ta_reference FROM ref_drug_indication_clusters WHERE UPPER(drug_name) = UPPER(?) ORDER BY indication, cluster_id """ try: with db_manager.get_connection() as conn: cursor = conn.execute(query, (drug_name,)) rows = cursor.fetchall() results = [] for row in rows: results.append({ "drug_name": row["drug_name"], "indication": row["indication"], "cluster_id": row["cluster_id"], "cluster_description": row["cluster_description"], "nice_ta_reference": row["nice_ta_reference"], }) logger.debug(f"Found {len(results)} cluster mappings for drug '{drug_name}'") return results except Exception as e: logger.error(f"Error getting clusters for drug '{drug_name}': {e}") return [] def get_drug_cluster_ids( drug_name: str, db_manager: Optional[DatabaseManager] = None ) -> list[str]: """ Get unique cluster IDs for a drug. Args: drug_name: Drug name to look up db_manager: Optional DatabaseManager Returns: List of unique cluster IDs """ clusters = get_drug_clusters(drug_name, db_manager) return list(set(c["cluster_id"] for c in clusters)) def get_cluster_snomed_codes( cluster_id: str, connector: Optional[SnowflakeConnector] = None, use_cache: bool = True, ) -> ClusterSnomedCodes: """ Get all SNOMED codes for a cluster from Snowflake. Queries the ClinicalCodingClusterSnomedCodes table to get all SNOMED codes that belong to the specified cluster. Args: cluster_id: Cluster ID to look up (e.g., 'RARTH_COD', 'PSORIASIS_COD') connector: Optional SnowflakeConnector (defaults to singleton) use_cache: Whether to use cached results (default True) Returns: ClusterSnomedCodes with list of SNOMED codes and descriptions """ if not SNOWFLAKE_AVAILABLE: logger.warning("Snowflake connector not available") return ClusterSnomedCodes(cluster_id=cluster_id, cluster_description="") if not is_snowflake_configured(): logger.warning("Snowflake not configured - cannot get cluster codes") return ClusterSnomedCodes(cluster_id=cluster_id, cluster_description="") # Check cache first cache_key = f"cluster_snomed_{cluster_id}" if use_cache and is_cache_enabled(): cache = get_cache() cached = cache.get(cache_key) if cached is not None and len(cached) > 0: logger.debug(f"Using cached SNOMED codes for cluster '{cluster_id}'") cached_dict = cached[0] # First element is our data dict return ClusterSnomedCodes( cluster_id=cluster_id, cluster_description=str(cached_dict.get("description", "")), snomed_codes=list(cached_dict.get("codes", [])), snomed_descriptions=dict(cached_dict.get("descriptions", {})), ) if connector is None: connector = get_connector() query = ''' SELECT DISTINCT "Cluster_ID", "Cluster_Description", "SNOMEDCode", "SNOMEDDescription" FROM DATA_HUB.PHM."ClinicalCodingClusterSnomedCodes" WHERE "Cluster_ID" = %s ORDER BY "SNOMEDCode" ''' try: results = connector.execute_dict(query, (cluster_id,)) if not results: logger.warning(f"No SNOMED codes found for cluster '{cluster_id}'") return ClusterSnomedCodes(cluster_id=cluster_id, cluster_description="") codes = [] descriptions = {} description = results[0].get("Cluster_Description", "") if results else "" for row in results: code = row.get("SNOMEDCode") if code: codes.append(code) descriptions[code] = row.get("SNOMEDDescription", "") logger.info(f"Found {len(codes)} SNOMED codes for cluster '{cluster_id}'") # Cache the results (using query-based cache with fake params) if use_cache and is_cache_enabled(): cache = get_cache() cache_data = [{ "description": description, "codes": codes, "descriptions": descriptions, }] cache.set(cache_key, None, cache_data) # type: ignore[arg-type] return ClusterSnomedCodes( cluster_id=cluster_id, cluster_description=description, snomed_codes=codes, snomed_descriptions=descriptions, ) except Exception as e: logger.error(f"Error getting SNOMED codes for cluster '{cluster_id}': {e}") return ClusterSnomedCodes(cluster_id=cluster_id, cluster_description="") def patient_has_indication( patient_pseudonym: str, cluster_ids: list[str], connector: Optional[SnowflakeConnector] = None, before_date: Optional[date] = None, ) -> tuple[bool, Optional[str], Optional[str], Optional[str]]: """ Check if a patient has any SNOMED codes from the specified clusters in GP records. Args: patient_pseudonym: Patient's pseudonymised NHS number cluster_ids: List of cluster IDs to check against connector: Optional SnowflakeConnector before_date: Optional date - only check diagnoses before this date Returns: Tuple of (has_indication, matched_cluster_id, matched_snomed_code, matched_description) """ if not SNOWFLAKE_AVAILABLE or not is_snowflake_configured(): return False, None, None, None if not cluster_ids: return False, None, None, None if connector is None: connector = get_connector() # Build placeholders for cluster IDs placeholders = ", ".join(["%s"] * len(cluster_ids)) # Query to check if patient has any matching SNOMED code query = f''' SELECT pc."SNOMEDCode", cc."Cluster_ID", cc."SNOMEDDescription" FROM DATA_HUB.PHM."PrimaryCareClinicalCoding" pc INNER JOIN DATA_HUB.PHM."ClinicalCodingClusterSnomedCodes" cc ON pc."SNOMEDCode" = cc."SNOMEDCode" WHERE pc."PatientPseudonym" = %s AND cc."Cluster_ID" IN ({placeholders}) ''' params = [patient_pseudonym] + cluster_ids if before_date: query += ' AND pc."EventDateTime" < %s' params.append(before_date.isoformat()) query += ' LIMIT 1' try: results = connector.execute_dict(query, tuple(params)) if results: row = results[0] return ( True, row.get("Cluster_ID"), row.get("SNOMEDCode"), row.get("SNOMEDDescription"), ) return False, None, None, None except Exception as e: logger.error(f"Error checking indication for patient '{patient_pseudonym}': {e}") return False, None, None, None def validate_indication( patient_pseudonym: str, drug_name: str, connector: Optional[SnowflakeConnector] = None, db_manager: Optional[DatabaseManager] = None, before_date: Optional[date] = None, ) -> IndicationValidationResult: """ Validate that a patient has an appropriate indication for a drug. Full validation workflow: 1. Get drug's valid indication clusters from local mapping 2. Check if patient has any matching SNOMED codes in GP records 3. Return detailed validation result Args: patient_pseudonym: Patient's pseudonymised NHS number drug_name: Drug name to validate indication for connector: Optional SnowflakeConnector db_manager: Optional DatabaseManager before_date: Optional date - only check diagnoses before this date Returns: IndicationValidationResult with validation details """ result = IndicationValidationResult( patient_pseudonym=patient_pseudonym, drug_name=drug_name, has_valid_indication=False, ) # Step 1: Get drug's cluster mappings cluster_ids = get_drug_cluster_ids(drug_name, db_manager) if not cluster_ids: result.error_message = f"No cluster mappings found for drug '{drug_name}'" result.source = "NONE" return result result.checked_clusters = cluster_ids # Step 2: Check Snowflake availability if not SNOWFLAKE_AVAILABLE: result.error_message = "Snowflake connector not installed" result.source = "NONE" return result if not is_snowflake_configured(): result.error_message = "Snowflake not configured" result.source = "NONE" return result # Step 3: Check patient GP records has_indication, matched_cluster, matched_code, matched_desc = patient_has_indication( patient_pseudonym=patient_pseudonym, cluster_ids=cluster_ids, connector=connector, before_date=before_date, ) result.has_valid_indication = has_indication result.matched_cluster_id = matched_cluster result.matched_snomed_code = matched_code result.matched_snomed_description = matched_desc result.source = "GP_SNOMED" if has_indication else "NONE" return result def get_indication_match_rate( drug_name: str, patient_pseudonyms: list[str], connector: Optional[SnowflakeConnector] = None, db_manager: Optional[DatabaseManager] = None, sample_unmatched_count: int = 10, ) -> DrugIndicationMatchRate: """ Calculate indication match rate for a drug across a list of patients. Args: drug_name: Drug name to check patient_pseudonyms: List of patient pseudonymised NHS numbers connector: Optional SnowflakeConnector db_manager: Optional DatabaseManager sample_unmatched_count: Number of unmatched patient IDs to include in sample Returns: DrugIndicationMatchRate with match statistics """ if connector is None and SNOWFLAKE_AVAILABLE and is_snowflake_configured(): connector = get_connector() cluster_ids = get_drug_cluster_ids(drug_name, db_manager) total = len(patient_pseudonyms) matched = 0 unmatched = 0 sample_unmatched: list[str] = [] if not cluster_ids: logger.warning(f"No cluster mappings for drug '{drug_name}' - all patients will be unmatched") return DrugIndicationMatchRate( drug_name=drug_name, total_patients=total, patients_with_indication=0, patients_without_indication=total, match_rate=0.0, clusters_checked=[], sample_unmatched=patient_pseudonyms[:sample_unmatched_count], ) for i, pseudonym in enumerate(patient_pseudonyms): if i > 0 and i % 100 == 0: logger.info(f"Validating indications: {i}/{total} ({100*i/total:.1f}%)") has_indication, _, _, _ = patient_has_indication( patient_pseudonym=pseudonym, cluster_ids=cluster_ids, connector=connector, ) if has_indication: matched += 1 else: unmatched += 1 if len(sample_unmatched) < sample_unmatched_count: sample_unmatched.append(pseudonym) match_rate = matched / total if total > 0 else 0.0 logger.info(f"Indication match rate for '{drug_name}': {100*match_rate:.1f}% ({matched}/{total})") return DrugIndicationMatchRate( drug_name=drug_name, total_patients=total, patients_with_indication=matched, patients_without_indication=unmatched, match_rate=match_rate, clusters_checked=cluster_ids, sample_unmatched=sample_unmatched, ) def batch_validate_indications( patient_drug_pairs: list[tuple[str, str]], connector: Optional[SnowflakeConnector] = None, db_manager: Optional[DatabaseManager] = None, progress_callback: Optional[Callable[[int, int], None]] = None, ) -> list[IndicationValidationResult]: """ Validate indications for multiple patient-drug pairs efficiently. Args: patient_drug_pairs: List of (patient_pseudonym, drug_name) tuples connector: Optional SnowflakeConnector db_manager: Optional DatabaseManager progress_callback: Optional callback(current, total) for progress updates Returns: List of IndicationValidationResult for each pair """ results = [] total = len(patient_drug_pairs) # Cache cluster lookups by drug drug_clusters_cache = {} for i, (pseudonym, drug_name) in enumerate(patient_drug_pairs): if progress_callback: progress_callback(i + 1, total) # Get clusters from cache or lookup drug_upper = drug_name.upper() if drug_upper not in drug_clusters_cache: drug_clusters_cache[drug_upper] = get_drug_cluster_ids(drug_name, db_manager) cluster_ids = drug_clusters_cache[drug_upper] if not cluster_ids: results.append(IndicationValidationResult( patient_pseudonym=pseudonym, drug_name=drug_name, has_valid_indication=False, source="NONE", error_message=f"No cluster mappings for drug '{drug_name}'", )) continue # Check patient indication has_indication, matched_cluster, matched_code, matched_desc = patient_has_indication( patient_pseudonym=pseudonym, cluster_ids=cluster_ids, connector=connector, ) results.append(IndicationValidationResult( patient_pseudonym=pseudonym, drug_name=drug_name, has_valid_indication=has_indication, matched_cluster_id=matched_cluster, matched_snomed_code=matched_code, matched_snomed_description=matched_desc, checked_clusters=cluster_ids, source="GP_SNOMED" if has_indication else "NONE", )) matched_count = sum(1 for r in results if r.has_valid_indication) logger.info(f"Batch validation complete: {matched_count}/{total} ({100*matched_count/total:.1f}%) with valid indications") return results def get_available_clusters( connector: Optional[SnowflakeConnector] = None, ) -> list[dict]: """ Get list of all available SNOMED clusters from Snowflake. Returns: List of dicts with cluster_id, cluster_description, code_count """ if not SNOWFLAKE_AVAILABLE or not is_snowflake_configured(): logger.warning("Snowflake not available - cannot list clusters") return [] if connector is None: connector = get_connector() query = ''' SELECT "Cluster_ID", "Cluster_Description", COUNT(DISTINCT "SNOMEDCode") as code_count FROM DATA_HUB.PHM."ClinicalCodingClusterSnomedCodes" GROUP BY "Cluster_ID", "Cluster_Description" ORDER BY "Cluster_ID" ''' try: results = connector.execute_dict(query) clusters = [] for row in results: clusters.append({ "cluster_id": row.get("Cluster_ID"), "cluster_description": row.get("Cluster_Description"), "code_count": row.get("code_count", 0), }) logger.info(f"Found {len(clusters)} available SNOMED clusters") return clusters except Exception as e: logger.error(f"Error getting available clusters: {e}") return [] # Export public API __all__ = [ "ClusterSnomedCodes", "IndicationValidationResult", "DrugIndicationMatchRate", "get_drug_clusters", "get_drug_cluster_ids", "get_cluster_snomed_codes", "patient_has_indication", "validate_indication", "get_indication_match_rate", "batch_validate_indications", "get_available_clusters", ]