From d98cd4fd69dedf2852794a90b4714b5df949cbbc Mon Sep 17 00:00:00 2001 From: Andrew Charlwood Date: Fri, 6 Feb 2026 19:21:10 +0000 Subject: [PATCH] feat: add 7 analytics chart query functions (Task 9.2) New query functions in src/data_processing/pathway_queries.py: - get_drug_market_share: Level 3 drug nodes grouped by directory - get_pathway_costs: Level 4+ pathway nodes with cost_pp_pa - get_cost_waterfall: Directorate cost per patient from level 3 aggregation - get_drug_transitions: Sankey source/target drug transitions with ordinal line labels - get_dosing_intervals: Parsed average_spacing by trust/directory - get_drug_directory_matrix: Directory x drug pivot with patient/cost metrics - get_treatment_durations: Weighted avg_days by drug within directorates Thin wrappers added in dash_app/data/queries.py for all 7 functions. --- IMPLEMENTATION_PLAN.md | 6 +- dash_app/data/queries.py | 78 ++++ src/data_processing/pathway_queries.py | 477 +++++++++++++++++++++++++ 3 files changed, 558 insertions(+), 3 deletions(-) diff --git a/IMPLEMENTATION_PLAN.md b/IMPLEMENTATION_PLAN.md index 8eb7fa6..4f57b6e 100644 --- a/IMPLEMENTATION_PLAN.md +++ b/IMPLEMENTATION_PLAN.md @@ -359,7 +359,7 @@ Drawer selection → update_drug_selection → app-state store → load_pathway_ - **Checkpoint**: App starts, tab bar renders with all 8 tabs, icicle tab still works, other tabs show placeholder "Coming soon" messages ✓ ### 9.2 Query functions for all chart types -- [ ] Add to `src/data_processing/pathway_queries.py`: +- [x] Add to `src/data_processing/pathway_queries.py`: - `get_drug_market_share(db_path, date_filter_id, chart_type, directory=None, trust=None)` — Level 3 nodes grouped by directory, returning drug, value, colour - `get_pathway_costs(db_path, date_filter_id, chart_type, directory=None)` — Level 4+ nodes with cost_pp_pa, parsed pathway labels, patient counts - `get_cost_waterfall(db_path, date_filter_id, chart_type, trust=None)` — Level 2 nodes with cost_pp_pa per directorate/indication @@ -367,8 +367,8 @@ Drawer selection → update_drug_selection → app-state store → load_pathway_ - `get_dosing_intervals(db_path, date_filter_id, chart_type, drug=None)` — Level 3 nodes for a specific drug, parsed average_spacing by trust/directory - `get_drug_directory_matrix(db_path, date_filter_id, chart_type)` — Level 3 nodes pivoted as directory × drug with value/cost metrics - `get_treatment_durations(db_path, date_filter_id, chart_type, directory=None)` — Level 3 nodes with avg_days by drug within a directorate -- [ ] Add thin wrappers in `dash_app/data/queries.py` for each new function (resolve DB_PATH and delegate) -- **Checkpoint**: All 7 query functions return correct data via manual Python tests (`python -c "..."`) +- [x] Add thin wrappers in `dash_app/data/queries.py` for each new function (resolve DB_PATH and delegate) +- **Checkpoint**: All 7 query functions return correct data via manual Python tests (`python -c "..."`) ✓ ### 9.3 First-Line Market Share chart (Tab 2) - [ ] Create `dash_app/callbacks/market_share.py`: diff --git a/dash_app/data/queries.py b/dash_app/data/queries.py index 47e4e42..bfc4d32 100644 --- a/dash_app/data/queries.py +++ b/dash_app/data/queries.py @@ -11,6 +11,13 @@ from typing import Optional from data_processing.pathway_queries import ( load_initial_data as _load_initial_data, load_pathway_nodes as _load_pathway_nodes, + get_drug_market_share as _get_drug_market_share, + get_pathway_costs as _get_pathway_costs, + get_cost_waterfall as _get_cost_waterfall, + get_drug_transitions as _get_drug_transitions, + get_dosing_intervals as _get_dosing_intervals, + get_drug_directory_matrix as _get_drug_directory_matrix, + get_treatment_durations as _get_treatment_durations, ) DB_PATH = Path(__file__).resolve().parents[2] / "data" / "pathways.db" @@ -37,3 +44,74 @@ def load_pathway_data( selected_directorates=selected_directorates, selected_trusts=selected_trusts, ) + + +# --- Analytics chart query wrappers (Phase 9) --- + + +def get_drug_market_share( + date_filter_id: str = "all_6mo", + chart_type: str = "directory", + directory: Optional[str] = None, + trust: Optional[str] = None, +) -> list[dict]: + """Level 3 drug nodes grouped by directory with patient counts.""" + return _get_drug_market_share(DB_PATH, date_filter_id, chart_type, directory, trust) + + +def get_pathway_costs( + date_filter_id: str = "all_6mo", + chart_type: str = "directory", + directory: Optional[str] = None, + trust: Optional[str] = None, +) -> list[dict]: + """Level 4+ pathway nodes with annualized cost.""" + return _get_pathway_costs(DB_PATH, date_filter_id, chart_type, directory, trust) + + +def get_cost_waterfall( + date_filter_id: str = "all_6mo", + chart_type: str = "directory", + trust: Optional[str] = None, +) -> list[dict]: + """Level 2 directorate nodes with cost per patient.""" + return _get_cost_waterfall(DB_PATH, date_filter_id, chart_type, trust) + + +def get_drug_transitions( + date_filter_id: str = "all_6mo", + chart_type: str = "directory", + directory: Optional[str] = None, + trust: Optional[str] = None, +) -> dict: + """Drug transition data for Sankey diagram.""" + return _get_drug_transitions(DB_PATH, date_filter_id, chart_type, directory, trust) + + +def get_dosing_intervals( + date_filter_id: str = "all_6mo", + chart_type: str = "directory", + drug: Optional[str] = None, + trust: Optional[str] = None, +) -> list[dict]: + """Dosing interval data parsed from average_spacing.""" + return _get_dosing_intervals(DB_PATH, date_filter_id, chart_type, drug, trust) + + +def get_drug_directory_matrix( + date_filter_id: str = "all_6mo", + chart_type: str = "directory", + trust: Optional[str] = None, +) -> dict: + """Directory × drug matrix with patient counts and costs.""" + return _get_drug_directory_matrix(DB_PATH, date_filter_id, chart_type, trust) + + +def get_treatment_durations( + date_filter_id: str = "all_6mo", + chart_type: str = "directory", + directory: Optional[str] = None, + trust: Optional[str] = None, +) -> list[dict]: + """Treatment duration data (avg_days) by drug.""" + return _get_treatment_durations(DB_PATH, date_filter_id, chart_type, directory, trust) diff --git a/src/data_processing/pathway_queries.py b/src/data_processing/pathway_queries.py index 532b6a3..3da7fe7 100644 --- a/src/data_processing/pathway_queries.py +++ b/src/data_processing/pathway_queries.py @@ -330,3 +330,480 @@ def _empty_result(error: str = "") -> dict: "last_updated": "", "error": error, } + + +# --------------------------------------------------------------------------- +# Analytics chart query functions (Phase 9) +# --------------------------------------------------------------------------- + +def _safe_float(value, default=0.0): + """Convert a value to float, returning default for None/N/A/empty.""" + if value is None or value == "" or value == "N/A": + return default + try: + return float(value) + except (ValueError, TypeError): + return default + + +def get_drug_market_share( + db_path: Path, + date_filter_id: str, + chart_type: str, + directory: Optional[str] = None, + trust: Optional[str] = None, +) -> list[dict]: + """Level 3 drug nodes grouped by directory with patient counts and proportions. + + Returns list of dicts: [{directory, drug, patients, proportion, cost, cost_pp_pa}] + Sorted by directory total patients desc, then drug patients desc within each. + """ + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + try: + where = ["date_filter_id = ?", "chart_type = ?", "level = 3"] + params: list = [date_filter_id, chart_type] + + if directory: + where.append("directory = ?") + params.append(directory) + if trust: + where.append("trust_name = ?") + params.append(trust) + + query = f""" + SELECT labels AS drug, directory, value AS patients, + colour AS proportion, cost, cost_pp_pa, trust_name + FROM pathway_nodes + WHERE {' AND '.join(where)} + ORDER BY directory, value DESC + """ + rows = conn.execute(query, params).fetchall() + + # Aggregate across trusts: sum patients/cost per directory+drug + agg = {} + for r in rows: + key = (r["directory"], r["drug"]) + if key not in agg: + agg[key] = {"directory": r["directory"], "drug": r["drug"], + "patients": 0, "cost": 0.0} + agg[key]["patients"] += r["patients"] or 0 + agg[key]["cost"] += float(r["cost"]) if r["cost"] else 0.0 + + # Compute proportions within each directory + dir_totals = {} + for v in agg.values(): + dir_totals[v["directory"]] = dir_totals.get(v["directory"], 0) + v["patients"] + + result = [] + for v in agg.values(): + total = dir_totals.get(v["directory"], 1) + result.append({ + "directory": v["directory"], + "drug": v["drug"], + "patients": v["patients"], + "proportion": round(v["patients"] / total, 4) if total else 0, + "cost": round(v["cost"], 2), + "cost_pp_pa": round(v["cost"] / v["patients"], 2) if v["patients"] else 0, + }) + + # Sort: directory by total patients desc, drugs by patients desc within + result.sort(key=lambda x: (-dir_totals.get(x["directory"], 0), -x["patients"])) + return result + except sqlite3.Error: + return [] + finally: + conn.close() + + +def get_pathway_costs( + db_path: Path, + date_filter_id: str, + chart_type: str, + directory: Optional[str] = None, + trust: Optional[str] = None, +) -> list[dict]: + """Level 4+ pathway nodes with annualized cost and pathway labels. + + Returns list of dicts: [{ids, pathway_label, cost_pp_pa, patients, directory, drug_sequence}] + Sorted by cost_pp_pa desc. + """ + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + try: + where = ["date_filter_id = ?", "chart_type = ?", "level >= 4"] + params: list = [date_filter_id, chart_type] + + if directory: + where.append("directory = ?") + params.append(directory) + if trust: + where.append("trust_name = ?") + params.append(trust) + + query = f""" + SELECT ids, labels, level, value AS patients, cost, costpp, + cost_pp_pa, avg_days, directory, drug_sequence, trust_name + FROM pathway_nodes + WHERE {' AND '.join(where)} + ORDER BY value DESC + """ + rows = conn.execute(query, params).fetchall() + + result = [] + for r in rows: + cpp = _safe_float(r["cost_pp_pa"]) + drugs = (r["drug_sequence"] or "").split("|") + pathway_label = " → ".join(d for d in drugs if d) + result.append({ + "ids": r["ids"], + "pathway_label": pathway_label, + "cost_pp_pa": cpp, + "patients": r["patients"] or 0, + "cost": float(r["cost"]) if r["cost"] else 0.0, + "avg_days": _safe_float(r["avg_days"]), + "directory": r["directory"] or "", + "trust_name": r["trust_name"] or "", + "drug_sequence": drugs, + "level": r["level"], + }) + + result.sort(key=lambda x: -x["cost_pp_pa"]) + return result + except sqlite3.Error: + return [] + finally: + conn.close() + + +def get_cost_waterfall( + db_path: Path, + date_filter_id: str, + chart_type: str, + trust: Optional[str] = None, +) -> list[dict]: + """Level 2 directorate/indication nodes with cost metrics. + + Since level 2 cost_pp_pa is 'N/A', we compute it from child (level 3) nodes: + sum(cost) / sum(patients) for each directory. + + Returns list of dicts: [{directory, patients, total_cost, cost_pp}] + Sorted by cost_pp desc. + """ + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + try: + where = ["date_filter_id = ?", "chart_type = ?", "level = 3"] + params: list = [date_filter_id, chart_type] + + if trust: + where.append("trust_name = ?") + params.append(trust) + + query = f""" + SELECT directory, SUM(value) AS patients, SUM(cost) AS total_cost + FROM pathway_nodes + WHERE {' AND '.join(where)} + GROUP BY directory + HAVING patients > 0 + ORDER BY total_cost DESC + """ + rows = conn.execute(query, params).fetchall() + + result = [] + for r in rows: + patients = r["patients"] or 0 + total_cost = float(r["total_cost"]) if r["total_cost"] else 0.0 + result.append({ + "directory": r["directory"] or "", + "patients": patients, + "total_cost": round(total_cost, 2), + "cost_pp": round(total_cost / patients, 2) if patients else 0, + }) + + result.sort(key=lambda x: -x["cost_pp"]) + return result + except sqlite3.Error: + return [] + finally: + conn.close() + + +def get_drug_transitions( + db_path: Path, + date_filter_id: str, + chart_type: str, + directory: Optional[str] = None, + trust: Optional[str] = None, +) -> dict: + """Parse level 3+ nodes into source→target drug transitions for Sankey. + + Returns dict with: + nodes: [{name, total_patients}] — unique drug names + links: [{source_idx, target_idx, patients}] — transitions between drugs + """ + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + try: + where = ["date_filter_id = ?", "chart_type = ?", "level >= 4"] + params: list = [date_filter_id, chart_type] + + if directory: + where.append("directory = ?") + params.append(directory) + if trust: + where.append("trust_name = ?") + params.append(trust) + + query = f""" + SELECT ids, level, value AS patients, drug_sequence, directory + FROM pathway_nodes + WHERE {' AND '.join(where)} + ORDER BY level, ids + """ + rows = conn.execute(query, params).fetchall() + + # Build node list and link aggregation + # Each drug at each treatment line is a separate Sankey node + # e.g., "ADALIMUMAB (1st)" and "ADALIMUMAB (2nd)" are different nodes + drug_line_set = set() + link_agg = {} + + for r in rows: + drugs = [d for d in (r["drug_sequence"] or "").split("|") if d] + patients = r["patients"] or 0 + if len(drugs) < 2 or patients == 0: + continue + + # Only use adjacent transitions (drug[i] → drug[i+1]) + for i in range(len(drugs) - 1): + src = f"{drugs[i]} ({_ordinal(i + 1)})" + tgt = f"{drugs[i + 1]} ({_ordinal(i + 2)})" + drug_line_set.add(src) + drug_line_set.add(tgt) + key = (src, tgt) + link_agg[key] = link_agg.get(key, 0) + patients + + # Build indexed node list + node_list = sorted(drug_line_set) + node_idx = {name: i for i, name in enumerate(node_list)} + + nodes = [{"name": name} for name in node_list] + links = [ + {"source_idx": node_idx[src], "target_idx": node_idx[tgt], "patients": pts} + for (src, tgt), pts in sorted(link_agg.items(), key=lambda x: -x[1]) + ] + + return {"nodes": nodes, "links": links} + except sqlite3.Error: + return {"nodes": [], "links": []} + finally: + conn.close() + + +def _ordinal(n: int) -> str: + """Return '1st', '2nd', '3rd', '4th', etc.""" + if 11 <= n % 100 <= 13: + return f"{n}th" + suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th") + return f"{n}{suffix}" + + +def get_dosing_intervals( + db_path: Path, + date_filter_id: str, + chart_type: str, + drug: Optional[str] = None, + trust: Optional[str] = None, +) -> list[dict]: + """Level 3 drug nodes with parsed dosing interval data. + + Returns list of dicts: + [{drug, trust_name, directory, weekly_interval, dose_count, total_weeks, patients}] + """ + from data_processing.parsing import parse_average_spacing + + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + try: + where = ["date_filter_id = ?", "chart_type = ?", "level = 3", + "average_spacing IS NOT NULL", "average_spacing != ''"] + params: list = [date_filter_id, chart_type] + + if drug: + where.append("labels = ?") + params.append(drug) + if trust: + where.append("trust_name = ?") + params.append(trust) + + query = f""" + SELECT labels AS drug, trust_name, directory, value AS patients, + average_spacing + FROM pathway_nodes + WHERE {' AND '.join(where)} + ORDER BY labels, trust_name + """ + rows = conn.execute(query, params).fetchall() + + result = [] + for r in rows: + parsed = parse_average_spacing(r["average_spacing"]) + for entry in parsed: + result.append({ + "drug": entry["drug_name"], + "trust_name": r["trust_name"] or "", + "directory": r["directory"] or "", + "weekly_interval": entry["weekly_interval"], + "dose_count": entry["dose_count"], + "total_weeks": entry["total_weeks"], + "patients": r["patients"] or 0, + }) + + return result + except sqlite3.Error: + return [] + finally: + conn.close() + + +def get_drug_directory_matrix( + db_path: Path, + date_filter_id: str, + chart_type: str, + trust: Optional[str] = None, +) -> dict: + """Level 3 nodes pivoted as directory × drug matrix. + + Returns dict with: + directories: sorted list of directory names (rows) + drugs: sorted list of drug names (columns) + matrix: {directory: {drug: {patients, cost, cost_pp_pa}}} + """ + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + try: + where = ["date_filter_id = ?", "chart_type = ?", "level = 3"] + params: list = [date_filter_id, chart_type] + + if trust: + where.append("trust_name = ?") + params.append(trust) + + query = f""" + SELECT labels AS drug, directory, value AS patients, cost, cost_pp_pa + FROM pathway_nodes + WHERE {' AND '.join(where)} + ORDER BY directory, labels + """ + rows = conn.execute(query, params).fetchall() + + # Aggregate across trusts + matrix = {} + dir_totals = {} + drug_totals = {} + + for r in rows: + d = r["directory"] or "" + drug = r["drug"] or "" + patients = r["patients"] or 0 + cost = float(r["cost"]) if r["cost"] else 0.0 + cpp = _safe_float(r["cost_pp_pa"]) + + if d not in matrix: + matrix[d] = {} + if drug not in matrix[d]: + matrix[d][drug] = {"patients": 0, "cost": 0.0} + + matrix[d][drug]["patients"] += patients + matrix[d][drug]["cost"] += cost + + dir_totals[d] = dir_totals.get(d, 0) + patients + drug_totals[drug] = drug_totals.get(drug, 0) + patients + + # Add cost_pp_pa to each cell + for d in matrix: + for drug in matrix[d]: + cell = matrix[d][drug] + cell["cost"] = round(cell["cost"], 2) + cell["cost_pp_pa"] = ( + round(cell["cost"] / cell["patients"], 2) if cell["patients"] else 0 + ) + + # Sort directories by total patients desc, drugs by frequency desc + directories = sorted(dir_totals, key=lambda x: -dir_totals[x]) + drugs = sorted(drug_totals, key=lambda x: -drug_totals[x]) + + return {"directories": directories, "drugs": drugs, "matrix": matrix} + except sqlite3.Error: + return {"directories": [], "drugs": [], "matrix": {}} + finally: + conn.close() + + +def get_treatment_durations( + db_path: Path, + date_filter_id: str, + chart_type: str, + directory: Optional[str] = None, + trust: Optional[str] = None, +) -> list[dict]: + """Level 3 drug nodes with average treatment duration. + + Returns list of dicts: [{drug, avg_days, patients, directory}] + Sorted by avg_days desc. Excludes nodes with no avg_days data. + """ + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + try: + where = ["date_filter_id = ?", "chart_type = ?", "level = 3", + "avg_days IS NOT NULL"] + params: list = [date_filter_id, chart_type] + + if directory: + where.append("directory = ?") + params.append(directory) + if trust: + where.append("trust_name = ?") + params.append(trust) + + query = f""" + SELECT labels AS drug, directory, trust_name, + value AS patients, avg_days + FROM pathway_nodes + WHERE {' AND '.join(where)} + ORDER BY avg_days DESC + """ + rows = conn.execute(query, params).fetchall() + + # Aggregate across trusts: weighted average of avg_days by patients + agg = {} + for r in rows: + key = (r["directory"] or "", r["drug"]) + patients = r["patients"] or 0 + days = _safe_float(r["avg_days"]) + if patients == 0 or days == 0: + continue + + if key not in agg: + agg[key] = {"drug": r["drug"], "directory": r["directory"] or "", + "total_weighted_days": 0.0, "total_patients": 0} + agg[key]["total_weighted_days"] += days * patients + agg[key]["total_patients"] += patients + + result = [] + for v in agg.values(): + if v["total_patients"] > 0: + result.append({ + "drug": v["drug"], + "directory": v["directory"], + "avg_days": round(v["total_weighted_days"] / v["total_patients"], 1), + "patients": v["total_patients"], + }) + + result.sort(key=lambda x: -x["avg_days"]) + return result + except sqlite3.Error: + return [] + finally: + conn.close()