feat: add 7 analytics chart query functions (Task 9.2)

New query functions in src/data_processing/pathway_queries.py:
- get_drug_market_share: Level 3 drug nodes grouped by directory
- get_pathway_costs: Level 4+ pathway nodes with cost_pp_pa
- get_cost_waterfall: Directorate cost per patient from level 3 aggregation
- get_drug_transitions: Sankey source/target drug transitions with ordinal line labels
- get_dosing_intervals: Parsed average_spacing by trust/directory
- get_drug_directory_matrix: Directory x drug pivot with patient/cost metrics
- get_treatment_durations: Weighted avg_days by drug within directorates

Thin wrappers added in dash_app/data/queries.py for all 7 functions.
This commit is contained in:
Andrew Charlwood
2026-02-06 19:21:10 +00:00
parent b34a1138fc
commit d98cd4fd69
3 changed files with 558 additions and 3 deletions
+3 -3
View File
@@ -359,7 +359,7 @@ Drawer selection → update_drug_selection → app-state store → load_pathway_
- **Checkpoint**: App starts, tab bar renders with all 8 tabs, icicle tab still works, other tabs show placeholder "Coming soon" messages ✓ - **Checkpoint**: App starts, tab bar renders with all 8 tabs, icicle tab still works, other tabs show placeholder "Coming soon" messages ✓
### 9.2 Query functions for all chart types ### 9.2 Query functions for all chart types
- [ ] Add to `src/data_processing/pathway_queries.py`: - [x] Add to `src/data_processing/pathway_queries.py`:
- `get_drug_market_share(db_path, date_filter_id, chart_type, directory=None, trust=None)` — Level 3 nodes grouped by directory, returning drug, value, colour - `get_drug_market_share(db_path, date_filter_id, chart_type, directory=None, trust=None)` — Level 3 nodes grouped by directory, returning drug, value, colour
- `get_pathway_costs(db_path, date_filter_id, chart_type, directory=None)` — Level 4+ nodes with cost_pp_pa, parsed pathway labels, patient counts - `get_pathway_costs(db_path, date_filter_id, chart_type, directory=None)` — Level 4+ nodes with cost_pp_pa, parsed pathway labels, patient counts
- `get_cost_waterfall(db_path, date_filter_id, chart_type, trust=None)` — Level 2 nodes with cost_pp_pa per directorate/indication - `get_cost_waterfall(db_path, date_filter_id, chart_type, trust=None)` — Level 2 nodes with cost_pp_pa per directorate/indication
@@ -367,8 +367,8 @@ Drawer selection → update_drug_selection → app-state store → load_pathway_
- `get_dosing_intervals(db_path, date_filter_id, chart_type, drug=None)` — Level 3 nodes for a specific drug, parsed average_spacing by trust/directory - `get_dosing_intervals(db_path, date_filter_id, chart_type, drug=None)` — Level 3 nodes for a specific drug, parsed average_spacing by trust/directory
- `get_drug_directory_matrix(db_path, date_filter_id, chart_type)` — Level 3 nodes pivoted as directory × drug with value/cost metrics - `get_drug_directory_matrix(db_path, date_filter_id, chart_type)` — Level 3 nodes pivoted as directory × drug with value/cost metrics
- `get_treatment_durations(db_path, date_filter_id, chart_type, directory=None)` — Level 3 nodes with avg_days by drug within a directorate - `get_treatment_durations(db_path, date_filter_id, chart_type, directory=None)` — Level 3 nodes with avg_days by drug within a directorate
- [ ] Add thin wrappers in `dash_app/data/queries.py` for each new function (resolve DB_PATH and delegate) - [x] Add thin wrappers in `dash_app/data/queries.py` for each new function (resolve DB_PATH and delegate)
- **Checkpoint**: All 7 query functions return correct data via manual Python tests (`python -c "..."`) - **Checkpoint**: All 7 query functions return correct data via manual Python tests (`python -c "..."`)
### 9.3 First-Line Market Share chart (Tab 2) ### 9.3 First-Line Market Share chart (Tab 2)
- [ ] Create `dash_app/callbacks/market_share.py`: - [ ] Create `dash_app/callbacks/market_share.py`:
+78
View File
@@ -11,6 +11,13 @@ from typing import Optional
from data_processing.pathway_queries import ( from data_processing.pathway_queries import (
load_initial_data as _load_initial_data, load_initial_data as _load_initial_data,
load_pathway_nodes as _load_pathway_nodes, load_pathway_nodes as _load_pathway_nodes,
get_drug_market_share as _get_drug_market_share,
get_pathway_costs as _get_pathway_costs,
get_cost_waterfall as _get_cost_waterfall,
get_drug_transitions as _get_drug_transitions,
get_dosing_intervals as _get_dosing_intervals,
get_drug_directory_matrix as _get_drug_directory_matrix,
get_treatment_durations as _get_treatment_durations,
) )
DB_PATH = Path(__file__).resolve().parents[2] / "data" / "pathways.db" DB_PATH = Path(__file__).resolve().parents[2] / "data" / "pathways.db"
@@ -37,3 +44,74 @@ def load_pathway_data(
selected_directorates=selected_directorates, selected_directorates=selected_directorates,
selected_trusts=selected_trusts, selected_trusts=selected_trusts,
) )
# --- Analytics chart query wrappers (Phase 9) ---
def get_drug_market_share(
date_filter_id: str = "all_6mo",
chart_type: str = "directory",
directory: Optional[str] = None,
trust: Optional[str] = None,
) -> list[dict]:
"""Level 3 drug nodes grouped by directory with patient counts."""
return _get_drug_market_share(DB_PATH, date_filter_id, chart_type, directory, trust)
def get_pathway_costs(
date_filter_id: str = "all_6mo",
chart_type: str = "directory",
directory: Optional[str] = None,
trust: Optional[str] = None,
) -> list[dict]:
"""Level 4+ pathway nodes with annualized cost."""
return _get_pathway_costs(DB_PATH, date_filter_id, chart_type, directory, trust)
def get_cost_waterfall(
date_filter_id: str = "all_6mo",
chart_type: str = "directory",
trust: Optional[str] = None,
) -> list[dict]:
"""Level 2 directorate nodes with cost per patient."""
return _get_cost_waterfall(DB_PATH, date_filter_id, chart_type, trust)
def get_drug_transitions(
date_filter_id: str = "all_6mo",
chart_type: str = "directory",
directory: Optional[str] = None,
trust: Optional[str] = None,
) -> dict:
"""Drug transition data for Sankey diagram."""
return _get_drug_transitions(DB_PATH, date_filter_id, chart_type, directory, trust)
def get_dosing_intervals(
date_filter_id: str = "all_6mo",
chart_type: str = "directory",
drug: Optional[str] = None,
trust: Optional[str] = None,
) -> list[dict]:
"""Dosing interval data parsed from average_spacing."""
return _get_dosing_intervals(DB_PATH, date_filter_id, chart_type, drug, trust)
def get_drug_directory_matrix(
date_filter_id: str = "all_6mo",
chart_type: str = "directory",
trust: Optional[str] = None,
) -> dict:
"""Directory × drug matrix with patient counts and costs."""
return _get_drug_directory_matrix(DB_PATH, date_filter_id, chart_type, trust)
def get_treatment_durations(
date_filter_id: str = "all_6mo",
chart_type: str = "directory",
directory: Optional[str] = None,
trust: Optional[str] = None,
) -> list[dict]:
"""Treatment duration data (avg_days) by drug."""
return _get_treatment_durations(DB_PATH, date_filter_id, chart_type, directory, trust)
+477
View File
@@ -330,3 +330,480 @@ def _empty_result(error: str = "") -> dict:
"last_updated": "", "last_updated": "",
"error": error, "error": error,
} }
# ---------------------------------------------------------------------------
# Analytics chart query functions (Phase 9)
# ---------------------------------------------------------------------------
def _safe_float(value, default=0.0):
"""Convert a value to float, returning default for None/N/A/empty."""
if value is None or value == "" or value == "N/A":
return default
try:
return float(value)
except (ValueError, TypeError):
return default
def get_drug_market_share(
db_path: Path,
date_filter_id: str,
chart_type: str,
directory: Optional[str] = None,
trust: Optional[str] = None,
) -> list[dict]:
"""Level 3 drug nodes grouped by directory with patient counts and proportions.
Returns list of dicts: [{directory, drug, patients, proportion, cost, cost_pp_pa}]
Sorted by directory total patients desc, then drug patients desc within each.
"""
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
try:
where = ["date_filter_id = ?", "chart_type = ?", "level = 3"]
params: list = [date_filter_id, chart_type]
if directory:
where.append("directory = ?")
params.append(directory)
if trust:
where.append("trust_name = ?")
params.append(trust)
query = f"""
SELECT labels AS drug, directory, value AS patients,
colour AS proportion, cost, cost_pp_pa, trust_name
FROM pathway_nodes
WHERE {' AND '.join(where)}
ORDER BY directory, value DESC
"""
rows = conn.execute(query, params).fetchall()
# Aggregate across trusts: sum patients/cost per directory+drug
agg = {}
for r in rows:
key = (r["directory"], r["drug"])
if key not in agg:
agg[key] = {"directory": r["directory"], "drug": r["drug"],
"patients": 0, "cost": 0.0}
agg[key]["patients"] += r["patients"] or 0
agg[key]["cost"] += float(r["cost"]) if r["cost"] else 0.0
# Compute proportions within each directory
dir_totals = {}
for v in agg.values():
dir_totals[v["directory"]] = dir_totals.get(v["directory"], 0) + v["patients"]
result = []
for v in agg.values():
total = dir_totals.get(v["directory"], 1)
result.append({
"directory": v["directory"],
"drug": v["drug"],
"patients": v["patients"],
"proportion": round(v["patients"] / total, 4) if total else 0,
"cost": round(v["cost"], 2),
"cost_pp_pa": round(v["cost"] / v["patients"], 2) if v["patients"] else 0,
})
# Sort: directory by total patients desc, drugs by patients desc within
result.sort(key=lambda x: (-dir_totals.get(x["directory"], 0), -x["patients"]))
return result
except sqlite3.Error:
return []
finally:
conn.close()
def get_pathway_costs(
db_path: Path,
date_filter_id: str,
chart_type: str,
directory: Optional[str] = None,
trust: Optional[str] = None,
) -> list[dict]:
"""Level 4+ pathway nodes with annualized cost and pathway labels.
Returns list of dicts: [{ids, pathway_label, cost_pp_pa, patients, directory, drug_sequence}]
Sorted by cost_pp_pa desc.
"""
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
try:
where = ["date_filter_id = ?", "chart_type = ?", "level >= 4"]
params: list = [date_filter_id, chart_type]
if directory:
where.append("directory = ?")
params.append(directory)
if trust:
where.append("trust_name = ?")
params.append(trust)
query = f"""
SELECT ids, labels, level, value AS patients, cost, costpp,
cost_pp_pa, avg_days, directory, drug_sequence, trust_name
FROM pathway_nodes
WHERE {' AND '.join(where)}
ORDER BY value DESC
"""
rows = conn.execute(query, params).fetchall()
result = []
for r in rows:
cpp = _safe_float(r["cost_pp_pa"])
drugs = (r["drug_sequence"] or "").split("|")
pathway_label = "".join(d for d in drugs if d)
result.append({
"ids": r["ids"],
"pathway_label": pathway_label,
"cost_pp_pa": cpp,
"patients": r["patients"] or 0,
"cost": float(r["cost"]) if r["cost"] else 0.0,
"avg_days": _safe_float(r["avg_days"]),
"directory": r["directory"] or "",
"trust_name": r["trust_name"] or "",
"drug_sequence": drugs,
"level": r["level"],
})
result.sort(key=lambda x: -x["cost_pp_pa"])
return result
except sqlite3.Error:
return []
finally:
conn.close()
def get_cost_waterfall(
db_path: Path,
date_filter_id: str,
chart_type: str,
trust: Optional[str] = None,
) -> list[dict]:
"""Level 2 directorate/indication nodes with cost metrics.
Since level 2 cost_pp_pa is 'N/A', we compute it from child (level 3) nodes:
sum(cost) / sum(patients) for each directory.
Returns list of dicts: [{directory, patients, total_cost, cost_pp}]
Sorted by cost_pp desc.
"""
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
try:
where = ["date_filter_id = ?", "chart_type = ?", "level = 3"]
params: list = [date_filter_id, chart_type]
if trust:
where.append("trust_name = ?")
params.append(trust)
query = f"""
SELECT directory, SUM(value) AS patients, SUM(cost) AS total_cost
FROM pathway_nodes
WHERE {' AND '.join(where)}
GROUP BY directory
HAVING patients > 0
ORDER BY total_cost DESC
"""
rows = conn.execute(query, params).fetchall()
result = []
for r in rows:
patients = r["patients"] or 0
total_cost = float(r["total_cost"]) if r["total_cost"] else 0.0
result.append({
"directory": r["directory"] or "",
"patients": patients,
"total_cost": round(total_cost, 2),
"cost_pp": round(total_cost / patients, 2) if patients else 0,
})
result.sort(key=lambda x: -x["cost_pp"])
return result
except sqlite3.Error:
return []
finally:
conn.close()
def get_drug_transitions(
db_path: Path,
date_filter_id: str,
chart_type: str,
directory: Optional[str] = None,
trust: Optional[str] = None,
) -> dict:
"""Parse level 3+ nodes into source→target drug transitions for Sankey.
Returns dict with:
nodes: [{name, total_patients}] — unique drug names
links: [{source_idx, target_idx, patients}] — transitions between drugs
"""
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
try:
where = ["date_filter_id = ?", "chart_type = ?", "level >= 4"]
params: list = [date_filter_id, chart_type]
if directory:
where.append("directory = ?")
params.append(directory)
if trust:
where.append("trust_name = ?")
params.append(trust)
query = f"""
SELECT ids, level, value AS patients, drug_sequence, directory
FROM pathway_nodes
WHERE {' AND '.join(where)}
ORDER BY level, ids
"""
rows = conn.execute(query, params).fetchall()
# Build node list and link aggregation
# Each drug at each treatment line is a separate Sankey node
# e.g., "ADALIMUMAB (1st)" and "ADALIMUMAB (2nd)" are different nodes
drug_line_set = set()
link_agg = {}
for r in rows:
drugs = [d for d in (r["drug_sequence"] or "").split("|") if d]
patients = r["patients"] or 0
if len(drugs) < 2 or patients == 0:
continue
# Only use adjacent transitions (drug[i] → drug[i+1])
for i in range(len(drugs) - 1):
src = f"{drugs[i]} ({_ordinal(i + 1)})"
tgt = f"{drugs[i + 1]} ({_ordinal(i + 2)})"
drug_line_set.add(src)
drug_line_set.add(tgt)
key = (src, tgt)
link_agg[key] = link_agg.get(key, 0) + patients
# Build indexed node list
node_list = sorted(drug_line_set)
node_idx = {name: i for i, name in enumerate(node_list)}
nodes = [{"name": name} for name in node_list]
links = [
{"source_idx": node_idx[src], "target_idx": node_idx[tgt], "patients": pts}
for (src, tgt), pts in sorted(link_agg.items(), key=lambda x: -x[1])
]
return {"nodes": nodes, "links": links}
except sqlite3.Error:
return {"nodes": [], "links": []}
finally:
conn.close()
def _ordinal(n: int) -> str:
"""Return '1st', '2nd', '3rd', '4th', etc."""
if 11 <= n % 100 <= 13:
return f"{n}th"
suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
return f"{n}{suffix}"
def get_dosing_intervals(
db_path: Path,
date_filter_id: str,
chart_type: str,
drug: Optional[str] = None,
trust: Optional[str] = None,
) -> list[dict]:
"""Level 3 drug nodes with parsed dosing interval data.
Returns list of dicts:
[{drug, trust_name, directory, weekly_interval, dose_count, total_weeks, patients}]
"""
from data_processing.parsing import parse_average_spacing
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
try:
where = ["date_filter_id = ?", "chart_type = ?", "level = 3",
"average_spacing IS NOT NULL", "average_spacing != ''"]
params: list = [date_filter_id, chart_type]
if drug:
where.append("labels = ?")
params.append(drug)
if trust:
where.append("trust_name = ?")
params.append(trust)
query = f"""
SELECT labels AS drug, trust_name, directory, value AS patients,
average_spacing
FROM pathway_nodes
WHERE {' AND '.join(where)}
ORDER BY labels, trust_name
"""
rows = conn.execute(query, params).fetchall()
result = []
for r in rows:
parsed = parse_average_spacing(r["average_spacing"])
for entry in parsed:
result.append({
"drug": entry["drug_name"],
"trust_name": r["trust_name"] or "",
"directory": r["directory"] or "",
"weekly_interval": entry["weekly_interval"],
"dose_count": entry["dose_count"],
"total_weeks": entry["total_weeks"],
"patients": r["patients"] or 0,
})
return result
except sqlite3.Error:
return []
finally:
conn.close()
def get_drug_directory_matrix(
db_path: Path,
date_filter_id: str,
chart_type: str,
trust: Optional[str] = None,
) -> dict:
"""Level 3 nodes pivoted as directory × drug matrix.
Returns dict with:
directories: sorted list of directory names (rows)
drugs: sorted list of drug names (columns)
matrix: {directory: {drug: {patients, cost, cost_pp_pa}}}
"""
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
try:
where = ["date_filter_id = ?", "chart_type = ?", "level = 3"]
params: list = [date_filter_id, chart_type]
if trust:
where.append("trust_name = ?")
params.append(trust)
query = f"""
SELECT labels AS drug, directory, value AS patients, cost, cost_pp_pa
FROM pathway_nodes
WHERE {' AND '.join(where)}
ORDER BY directory, labels
"""
rows = conn.execute(query, params).fetchall()
# Aggregate across trusts
matrix = {}
dir_totals = {}
drug_totals = {}
for r in rows:
d = r["directory"] or ""
drug = r["drug"] or ""
patients = r["patients"] or 0
cost = float(r["cost"]) if r["cost"] else 0.0
cpp = _safe_float(r["cost_pp_pa"])
if d not in matrix:
matrix[d] = {}
if drug not in matrix[d]:
matrix[d][drug] = {"patients": 0, "cost": 0.0}
matrix[d][drug]["patients"] += patients
matrix[d][drug]["cost"] += cost
dir_totals[d] = dir_totals.get(d, 0) + patients
drug_totals[drug] = drug_totals.get(drug, 0) + patients
# Add cost_pp_pa to each cell
for d in matrix:
for drug in matrix[d]:
cell = matrix[d][drug]
cell["cost"] = round(cell["cost"], 2)
cell["cost_pp_pa"] = (
round(cell["cost"] / cell["patients"], 2) if cell["patients"] else 0
)
# Sort directories by total patients desc, drugs by frequency desc
directories = sorted(dir_totals, key=lambda x: -dir_totals[x])
drugs = sorted(drug_totals, key=lambda x: -drug_totals[x])
return {"directories": directories, "drugs": drugs, "matrix": matrix}
except sqlite3.Error:
return {"directories": [], "drugs": [], "matrix": {}}
finally:
conn.close()
def get_treatment_durations(
db_path: Path,
date_filter_id: str,
chart_type: str,
directory: Optional[str] = None,
trust: Optional[str] = None,
) -> list[dict]:
"""Level 3 drug nodes with average treatment duration.
Returns list of dicts: [{drug, avg_days, patients, directory}]
Sorted by avg_days desc. Excludes nodes with no avg_days data.
"""
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
try:
where = ["date_filter_id = ?", "chart_type = ?", "level = 3",
"avg_days IS NOT NULL"]
params: list = [date_filter_id, chart_type]
if directory:
where.append("directory = ?")
params.append(directory)
if trust:
where.append("trust_name = ?")
params.append(trust)
query = f"""
SELECT labels AS drug, directory, trust_name,
value AS patients, avg_days
FROM pathway_nodes
WHERE {' AND '.join(where)}
ORDER BY avg_days DESC
"""
rows = conn.execute(query, params).fetchall()
# Aggregate across trusts: weighted average of avg_days by patients
agg = {}
for r in rows:
key = (r["directory"] or "", r["drug"])
patients = r["patients"] or 0
days = _safe_float(r["avg_days"])
if patients == 0 or days == 0:
continue
if key not in agg:
agg[key] = {"drug": r["drug"], "directory": r["directory"] or "",
"total_weighted_days": 0.0, "total_patients": 0}
agg[key]["total_weighted_days"] += days * patients
agg[key]["total_patients"] += patients
result = []
for v in agg.values():
if v["total_patients"] > 0:
result.append({
"drug": v["drug"],
"directory": v["directory"],
"avg_days": round(v["total_weighted_days"] / v["total_patients"], 1),
"patients": v["total_patients"],
})
result.sort(key=lambda x: -x["avg_days"])
return result
except sqlite3.Error:
return []
finally:
conn.close()