d98cd4fd69
New query functions in src/data_processing/pathway_queries.py: - get_drug_market_share: Level 3 drug nodes grouped by directory - get_pathway_costs: Level 4+ pathway nodes with cost_pp_pa - get_cost_waterfall: Directorate cost per patient from level 3 aggregation - get_drug_transitions: Sankey source/target drug transitions with ordinal line labels - get_dosing_intervals: Parsed average_spacing by trust/directory - get_drug_directory_matrix: Directory x drug pivot with patient/cost metrics - get_treatment_durations: Weighted avg_days by drug within directorates Thin wrappers added in dash_app/data/queries.py for all 7 functions.
810 lines
28 KiB
Python
810 lines
28 KiB
Python
"""
|
||
Shared query functions for pathway node data.
|
||
|
||
These functions extract the data loading logic from the Reflex AppState
|
||
into standalone functions that accept db_path as a parameter and return
|
||
plain JSON-serializable dicts/lists. Both Reflex and Dash can call these.
|
||
|
||
All queries are read-only SELECTs against pathways.db.
|
||
"""
|
||
|
||
import sqlite3
|
||
from pathlib import Path
|
||
from typing import Optional
|
||
|
||
|
||
def load_initial_data(db_path: Path) -> dict:
|
||
"""
|
||
Load reference data from SQLite on app initialization.
|
||
|
||
Extracted from AppState.load_data() (pathways_app.py lines 407-488).
|
||
|
||
Returns dict with keys:
|
||
available_drugs: sorted list of unique drug labels (level 3 nodes)
|
||
available_directorates: sorted list of unique directorate labels (level 2, directory charts)
|
||
available_indications: sorted list of unique indications from ref_drug_indication_clusters
|
||
total_records: source row count from latest completed refresh
|
||
last_updated: ISO timestamp of latest completed refresh
|
||
"""
|
||
if not db_path.exists():
|
||
return {
|
||
"available_drugs": [],
|
||
"available_directorates": [],
|
||
"available_indications": [],
|
||
"available_trusts": [],
|
||
"total_records": 0,
|
||
"last_updated": "",
|
||
"error": "Database not found",
|
||
}
|
||
|
||
conn = sqlite3.connect(str(db_path))
|
||
try:
|
||
cursor = conn.cursor()
|
||
|
||
# Latest completed refresh metadata
|
||
cursor.execute("""
|
||
SELECT source_row_count, completed_at
|
||
FROM pathway_refresh_log
|
||
WHERE status = 'completed'
|
||
ORDER BY started_at DESC
|
||
LIMIT 1
|
||
""")
|
||
refresh_row = cursor.fetchone()
|
||
total_records = (refresh_row[0] or 0) if refresh_row else 0
|
||
last_updated = (refresh_row[1] or "") if refresh_row else ""
|
||
|
||
# Unique drugs from pathway_nodes level 3
|
||
cursor.execute("""
|
||
SELECT DISTINCT labels
|
||
FROM pathway_nodes
|
||
WHERE level = 3 AND labels IS NOT NULL AND labels != ''
|
||
ORDER BY labels
|
||
""")
|
||
available_drugs = [row[0] for row in cursor.fetchall()]
|
||
|
||
# Unique directorates from directory chart pathway_nodes level 2
|
||
cursor.execute("""
|
||
SELECT DISTINCT labels
|
||
FROM pathway_nodes
|
||
WHERE level = 2 AND chart_type = 'directory'
|
||
AND labels IS NOT NULL AND labels != ''
|
||
ORDER BY labels
|
||
""")
|
||
available_directorates = [row[0] for row in cursor.fetchall()]
|
||
|
||
# Unique indications from ref_drug_indication_clusters
|
||
cursor.execute("""
|
||
SELECT DISTINCT indication
|
||
FROM ref_drug_indication_clusters
|
||
WHERE indication IS NOT NULL AND indication != ''
|
||
ORDER BY indication
|
||
""")
|
||
available_indications = [row[0] for row in cursor.fetchall()]
|
||
if not available_indications:
|
||
available_indications = ["(No indications available)"]
|
||
|
||
# Unique trusts from pathway_nodes level 1
|
||
cursor.execute("""
|
||
SELECT DISTINCT trust_name
|
||
FROM pathway_nodes
|
||
WHERE level = 1 AND trust_name IS NOT NULL AND trust_name != ''
|
||
ORDER BY trust_name
|
||
""")
|
||
available_trusts = [row[0] for row in cursor.fetchall()]
|
||
|
||
# Total patients from default root node (fallback when source_row_count is empty)
|
||
total_patients = 0
|
||
if not total_records:
|
||
cursor.execute("""
|
||
SELECT value FROM pathway_nodes
|
||
WHERE level = 0 AND date_filter_id = 'all_6mo' AND chart_type = 'directory'
|
||
LIMIT 1
|
||
""")
|
||
root_row = cursor.fetchone()
|
||
total_patients = (root_row[0] or 0) if root_row else 0
|
||
|
||
return {
|
||
"available_drugs": available_drugs,
|
||
"available_directorates": available_directorates,
|
||
"available_indications": available_indications,
|
||
"available_trusts": available_trusts,
|
||
"total_records": total_records,
|
||
"total_patients": total_patients,
|
||
"last_updated": last_updated,
|
||
}
|
||
except sqlite3.Error as e:
|
||
return {
|
||
"available_drugs": [],
|
||
"available_directorates": [],
|
||
"available_indications": [],
|
||
"available_trusts": [],
|
||
"total_records": 0,
|
||
"last_updated": "",
|
||
"error": f"Database error: {e}",
|
||
}
|
||
finally:
|
||
conn.close()
|
||
|
||
|
||
def load_pathway_nodes(
|
||
db_path: Path,
|
||
filter_id: str,
|
||
chart_type: str,
|
||
selected_drugs: Optional[list[str]] = None,
|
||
selected_directorates: Optional[list[str]] = None,
|
||
selected_trusts: Optional[list[str]] = None,
|
||
) -> dict:
|
||
"""
|
||
Load pre-computed pathway nodes from SQLite.
|
||
|
||
Extracted from AppState.load_pathway_data() (pathways_app.py lines 490-642).
|
||
|
||
Args:
|
||
db_path: Path to pathways.db
|
||
filter_id: e.g. "all_6mo", "2yr_12mo"
|
||
chart_type: "directory" or "indication"
|
||
selected_drugs: optional list of drug names to filter by
|
||
selected_directorates: optional list of directorate names to filter by
|
||
selected_trusts: optional list of trust names to filter by
|
||
|
||
Returns dict with keys:
|
||
nodes: list of dicts (JSON-serializable) with chart node data
|
||
unique_patients: int (from root node)
|
||
total_drugs: int (unique drugs across level 3+ nodes)
|
||
total_cost: float (from root node)
|
||
last_updated: ISO timestamp string
|
||
error: optional error string
|
||
"""
|
||
if not db_path.exists():
|
||
return _empty_result("Database not found")
|
||
|
||
conn = sqlite3.connect(str(db_path))
|
||
conn.row_factory = sqlite3.Row
|
||
try:
|
||
cursor = conn.cursor()
|
||
|
||
# Build WHERE clause with parameterized values
|
||
where_clauses = ["date_filter_id = ?", "chart_type = ?"]
|
||
params: list = [filter_id, chart_type]
|
||
|
||
if selected_directorates:
|
||
placeholders = ",".join("?" * len(selected_directorates))
|
||
where_clauses.append(
|
||
f"(level < 2 OR directory IN ({placeholders}) OR directory IS NULL OR directory = '')"
|
||
)
|
||
params.extend(selected_directorates)
|
||
|
||
if selected_drugs:
|
||
drug_conditions = []
|
||
for drug in selected_drugs:
|
||
drug_conditions.append("drug_sequence LIKE ?")
|
||
params.append(f"%{drug}%")
|
||
where_clauses.append(
|
||
f"(level < 3 OR {' OR '.join(drug_conditions)})"
|
||
)
|
||
|
||
if selected_trusts:
|
||
placeholders = ",".join("?" * len(selected_trusts))
|
||
where_clauses.append(
|
||
f"(trust_name IN ({placeholders}) OR trust_name IS NULL OR trust_name = '')"
|
||
)
|
||
params.extend(selected_trusts)
|
||
|
||
where_clause = " AND ".join(where_clauses)
|
||
|
||
query = f"""
|
||
SELECT
|
||
parents, ids, labels, level, value,
|
||
cost, costpp, cost_pp_pa, colour,
|
||
first_seen, last_seen, first_seen_parent, last_seen_parent,
|
||
average_spacing, average_administered, avg_days,
|
||
trust_name, directory, drug_sequence
|
||
FROM pathway_nodes
|
||
WHERE {where_clause}
|
||
ORDER BY level, parents, ids
|
||
"""
|
||
|
||
cursor.execute(query, params)
|
||
rows = cursor.fetchall()
|
||
|
||
if not rows:
|
||
return _empty_result(f"No pathway data for filter: {filter_id}")
|
||
|
||
# When drug or directorate filters are active, prune ancestor nodes
|
||
# that have no matching descendants. Without this, the icicle chart
|
||
# shows empty directorate/trust boxes with no children.
|
||
if selected_drugs or selected_directorates:
|
||
rows = _prune_empty_ancestors(rows)
|
||
|
||
nodes = []
|
||
root_patients = 0
|
||
root_cost = 0.0
|
||
has_entity_filter = bool(selected_drugs or selected_directorates or selected_trusts)
|
||
|
||
for row in rows:
|
||
node = {
|
||
"parents": row["parents"] or "",
|
||
"ids": row["ids"] or "",
|
||
"labels": row["labels"] or "",
|
||
"value": row["value"] or 0,
|
||
"cost": float(row["cost"]) if row["cost"] else 0.0,
|
||
"costpp": float(row["costpp"]) if row["costpp"] else 0.0,
|
||
"colour": float(row["colour"]) if row["colour"] else 0.0,
|
||
"first_seen": row["first_seen"] or "",
|
||
"last_seen": row["last_seen"] or "",
|
||
"first_seen_parent": row["first_seen_parent"] or "",
|
||
"last_seen_parent": row["last_seen_parent"] or "",
|
||
"average_spacing": row["average_spacing"] or "",
|
||
"cost_pp_pa": row["cost_pp_pa"] or "",
|
||
}
|
||
nodes.append(node)
|
||
|
||
if row["level"] == 0:
|
||
root_patients = row["value"] or 0
|
||
root_cost = float(row["cost"]) if row["cost"] else 0.0
|
||
|
||
# Count unique drugs from level 3+ nodes
|
||
unique_drugs = set()
|
||
for row in rows:
|
||
if row["level"] >= 3 and row["drug_sequence"]:
|
||
for drug in row["drug_sequence"].split("|"):
|
||
if drug:
|
||
unique_drugs.add(drug)
|
||
|
||
# When entity filters are active, sum level-3 drug nodes for KPIs
|
||
# instead of using the root node's pre-computed (unfiltered) totals
|
||
if has_entity_filter:
|
||
filtered_patients = sum(
|
||
row["value"] or 0 for row in rows if row["level"] == 3
|
||
)
|
||
filtered_cost = sum(
|
||
float(row["cost"]) if row["cost"] else 0.0
|
||
for row in rows if row["level"] == 3
|
||
)
|
||
root_patients = filtered_patients
|
||
root_cost = filtered_cost
|
||
|
||
# Data freshness
|
||
cursor.execute("""
|
||
SELECT completed_at
|
||
FROM pathway_refresh_log
|
||
WHERE status = 'completed'
|
||
ORDER BY completed_at DESC
|
||
LIMIT 1
|
||
""")
|
||
refresh_row = cursor.fetchone()
|
||
last_updated = (
|
||
refresh_row["completed_at"] if refresh_row and refresh_row["completed_at"] else ""
|
||
)
|
||
|
||
return {
|
||
"nodes": nodes,
|
||
"unique_patients": root_patients,
|
||
"total_drugs": len(unique_drugs),
|
||
"total_cost": root_cost,
|
||
"last_updated": last_updated,
|
||
}
|
||
except sqlite3.Error as e:
|
||
return _empty_result(f"Database error: {e}")
|
||
finally:
|
||
conn.close()
|
||
|
||
|
||
def _prune_empty_ancestors(rows):
|
||
"""Remove ancestor nodes that have no matching descendants.
|
||
|
||
When drug/directorate filters are active, levels 0-2 are included
|
||
unconditionally. This leaves directorate and trust nodes that have no
|
||
children in the filtered result. Plotly's icicle chart shows these as
|
||
empty boxes. Prune them by keeping only nodes whose ids appear as a
|
||
parent of another kept node, or that are leaf nodes (level 3+), or
|
||
are the root (level 0).
|
||
"""
|
||
# Collect all parent references from the result set
|
||
referenced_parents = {row["parents"] for row in rows if row["parents"]}
|
||
# Keep: root (level 0), any node referenced as a parent, any leaf (level 3+)
|
||
kept = [
|
||
row for row in rows
|
||
if row["level"] == 0
|
||
or row["level"] >= 3
|
||
or row["ids"] in referenced_parents
|
||
]
|
||
# Second pass: a trust (level 1) may reference root but itself have no
|
||
# kept level-2 children. Recheck that level-1 nodes are still parents
|
||
# of something in the kept set.
|
||
kept_parents = {row["parents"] for row in kept if row["parents"]}
|
||
return [
|
||
row for row in kept
|
||
if row["level"] == 0
|
||
or row["level"] >= 3
|
||
or row["ids"] in kept_parents
|
||
]
|
||
|
||
|
||
def _empty_result(error: str = "") -> dict:
|
||
return {
|
||
"nodes": [],
|
||
"unique_patients": 0,
|
||
"total_drugs": 0,
|
||
"total_cost": 0.0,
|
||
"last_updated": "",
|
||
"error": error,
|
||
}
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Analytics chart query functions (Phase 9)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _safe_float(value, default=0.0):
|
||
"""Convert a value to float, returning default for None/N/A/empty."""
|
||
if value is None or value == "" or value == "N/A":
|
||
return default
|
||
try:
|
||
return float(value)
|
||
except (ValueError, TypeError):
|
||
return default
|
||
|
||
|
||
def get_drug_market_share(
|
||
db_path: Path,
|
||
date_filter_id: str,
|
||
chart_type: str,
|
||
directory: Optional[str] = None,
|
||
trust: Optional[str] = None,
|
||
) -> list[dict]:
|
||
"""Level 3 drug nodes grouped by directory with patient counts and proportions.
|
||
|
||
Returns list of dicts: [{directory, drug, patients, proportion, cost, cost_pp_pa}]
|
||
Sorted by directory total patients desc, then drug patients desc within each.
|
||
"""
|
||
conn = sqlite3.connect(str(db_path))
|
||
conn.row_factory = sqlite3.Row
|
||
try:
|
||
where = ["date_filter_id = ?", "chart_type = ?", "level = 3"]
|
||
params: list = [date_filter_id, chart_type]
|
||
|
||
if directory:
|
||
where.append("directory = ?")
|
||
params.append(directory)
|
||
if trust:
|
||
where.append("trust_name = ?")
|
||
params.append(trust)
|
||
|
||
query = f"""
|
||
SELECT labels AS drug, directory, value AS patients,
|
||
colour AS proportion, cost, cost_pp_pa, trust_name
|
||
FROM pathway_nodes
|
||
WHERE {' AND '.join(where)}
|
||
ORDER BY directory, value DESC
|
||
"""
|
||
rows = conn.execute(query, params).fetchall()
|
||
|
||
# Aggregate across trusts: sum patients/cost per directory+drug
|
||
agg = {}
|
||
for r in rows:
|
||
key = (r["directory"], r["drug"])
|
||
if key not in agg:
|
||
agg[key] = {"directory": r["directory"], "drug": r["drug"],
|
||
"patients": 0, "cost": 0.0}
|
||
agg[key]["patients"] += r["patients"] or 0
|
||
agg[key]["cost"] += float(r["cost"]) if r["cost"] else 0.0
|
||
|
||
# Compute proportions within each directory
|
||
dir_totals = {}
|
||
for v in agg.values():
|
||
dir_totals[v["directory"]] = dir_totals.get(v["directory"], 0) + v["patients"]
|
||
|
||
result = []
|
||
for v in agg.values():
|
||
total = dir_totals.get(v["directory"], 1)
|
||
result.append({
|
||
"directory": v["directory"],
|
||
"drug": v["drug"],
|
||
"patients": v["patients"],
|
||
"proportion": round(v["patients"] / total, 4) if total else 0,
|
||
"cost": round(v["cost"], 2),
|
||
"cost_pp_pa": round(v["cost"] / v["patients"], 2) if v["patients"] else 0,
|
||
})
|
||
|
||
# Sort: directory by total patients desc, drugs by patients desc within
|
||
result.sort(key=lambda x: (-dir_totals.get(x["directory"], 0), -x["patients"]))
|
||
return result
|
||
except sqlite3.Error:
|
||
return []
|
||
finally:
|
||
conn.close()
|
||
|
||
|
||
def get_pathway_costs(
|
||
db_path: Path,
|
||
date_filter_id: str,
|
||
chart_type: str,
|
||
directory: Optional[str] = None,
|
||
trust: Optional[str] = None,
|
||
) -> list[dict]:
|
||
"""Level 4+ pathway nodes with annualized cost and pathway labels.
|
||
|
||
Returns list of dicts: [{ids, pathway_label, cost_pp_pa, patients, directory, drug_sequence}]
|
||
Sorted by cost_pp_pa desc.
|
||
"""
|
||
conn = sqlite3.connect(str(db_path))
|
||
conn.row_factory = sqlite3.Row
|
||
try:
|
||
where = ["date_filter_id = ?", "chart_type = ?", "level >= 4"]
|
||
params: list = [date_filter_id, chart_type]
|
||
|
||
if directory:
|
||
where.append("directory = ?")
|
||
params.append(directory)
|
||
if trust:
|
||
where.append("trust_name = ?")
|
||
params.append(trust)
|
||
|
||
query = f"""
|
||
SELECT ids, labels, level, value AS patients, cost, costpp,
|
||
cost_pp_pa, avg_days, directory, drug_sequence, trust_name
|
||
FROM pathway_nodes
|
||
WHERE {' AND '.join(where)}
|
||
ORDER BY value DESC
|
||
"""
|
||
rows = conn.execute(query, params).fetchall()
|
||
|
||
result = []
|
||
for r in rows:
|
||
cpp = _safe_float(r["cost_pp_pa"])
|
||
drugs = (r["drug_sequence"] or "").split("|")
|
||
pathway_label = " → ".join(d for d in drugs if d)
|
||
result.append({
|
||
"ids": r["ids"],
|
||
"pathway_label": pathway_label,
|
||
"cost_pp_pa": cpp,
|
||
"patients": r["patients"] or 0,
|
||
"cost": float(r["cost"]) if r["cost"] else 0.0,
|
||
"avg_days": _safe_float(r["avg_days"]),
|
||
"directory": r["directory"] or "",
|
||
"trust_name": r["trust_name"] or "",
|
||
"drug_sequence": drugs,
|
||
"level": r["level"],
|
||
})
|
||
|
||
result.sort(key=lambda x: -x["cost_pp_pa"])
|
||
return result
|
||
except sqlite3.Error:
|
||
return []
|
||
finally:
|
||
conn.close()
|
||
|
||
|
||
def get_cost_waterfall(
|
||
db_path: Path,
|
||
date_filter_id: str,
|
||
chart_type: str,
|
||
trust: Optional[str] = None,
|
||
) -> list[dict]:
|
||
"""Level 2 directorate/indication nodes with cost metrics.
|
||
|
||
Since level 2 cost_pp_pa is 'N/A', we compute it from child (level 3) nodes:
|
||
sum(cost) / sum(patients) for each directory.
|
||
|
||
Returns list of dicts: [{directory, patients, total_cost, cost_pp}]
|
||
Sorted by cost_pp desc.
|
||
"""
|
||
conn = sqlite3.connect(str(db_path))
|
||
conn.row_factory = sqlite3.Row
|
||
try:
|
||
where = ["date_filter_id = ?", "chart_type = ?", "level = 3"]
|
||
params: list = [date_filter_id, chart_type]
|
||
|
||
if trust:
|
||
where.append("trust_name = ?")
|
||
params.append(trust)
|
||
|
||
query = f"""
|
||
SELECT directory, SUM(value) AS patients, SUM(cost) AS total_cost
|
||
FROM pathway_nodes
|
||
WHERE {' AND '.join(where)}
|
||
GROUP BY directory
|
||
HAVING patients > 0
|
||
ORDER BY total_cost DESC
|
||
"""
|
||
rows = conn.execute(query, params).fetchall()
|
||
|
||
result = []
|
||
for r in rows:
|
||
patients = r["patients"] or 0
|
||
total_cost = float(r["total_cost"]) if r["total_cost"] else 0.0
|
||
result.append({
|
||
"directory": r["directory"] or "",
|
||
"patients": patients,
|
||
"total_cost": round(total_cost, 2),
|
||
"cost_pp": round(total_cost / patients, 2) if patients else 0,
|
||
})
|
||
|
||
result.sort(key=lambda x: -x["cost_pp"])
|
||
return result
|
||
except sqlite3.Error:
|
||
return []
|
||
finally:
|
||
conn.close()
|
||
|
||
|
||
def get_drug_transitions(
|
||
db_path: Path,
|
||
date_filter_id: str,
|
||
chart_type: str,
|
||
directory: Optional[str] = None,
|
||
trust: Optional[str] = None,
|
||
) -> dict:
|
||
"""Parse level 3+ nodes into source→target drug transitions for Sankey.
|
||
|
||
Returns dict with:
|
||
nodes: [{name, total_patients}] — unique drug names
|
||
links: [{source_idx, target_idx, patients}] — transitions between drugs
|
||
"""
|
||
conn = sqlite3.connect(str(db_path))
|
||
conn.row_factory = sqlite3.Row
|
||
try:
|
||
where = ["date_filter_id = ?", "chart_type = ?", "level >= 4"]
|
||
params: list = [date_filter_id, chart_type]
|
||
|
||
if directory:
|
||
where.append("directory = ?")
|
||
params.append(directory)
|
||
if trust:
|
||
where.append("trust_name = ?")
|
||
params.append(trust)
|
||
|
||
query = f"""
|
||
SELECT ids, level, value AS patients, drug_sequence, directory
|
||
FROM pathway_nodes
|
||
WHERE {' AND '.join(where)}
|
||
ORDER BY level, ids
|
||
"""
|
||
rows = conn.execute(query, params).fetchall()
|
||
|
||
# Build node list and link aggregation
|
||
# Each drug at each treatment line is a separate Sankey node
|
||
# e.g., "ADALIMUMAB (1st)" and "ADALIMUMAB (2nd)" are different nodes
|
||
drug_line_set = set()
|
||
link_agg = {}
|
||
|
||
for r in rows:
|
||
drugs = [d for d in (r["drug_sequence"] or "").split("|") if d]
|
||
patients = r["patients"] or 0
|
||
if len(drugs) < 2 or patients == 0:
|
||
continue
|
||
|
||
# Only use adjacent transitions (drug[i] → drug[i+1])
|
||
for i in range(len(drugs) - 1):
|
||
src = f"{drugs[i]} ({_ordinal(i + 1)})"
|
||
tgt = f"{drugs[i + 1]} ({_ordinal(i + 2)})"
|
||
drug_line_set.add(src)
|
||
drug_line_set.add(tgt)
|
||
key = (src, tgt)
|
||
link_agg[key] = link_agg.get(key, 0) + patients
|
||
|
||
# Build indexed node list
|
||
node_list = sorted(drug_line_set)
|
||
node_idx = {name: i for i, name in enumerate(node_list)}
|
||
|
||
nodes = [{"name": name} for name in node_list]
|
||
links = [
|
||
{"source_idx": node_idx[src], "target_idx": node_idx[tgt], "patients": pts}
|
||
for (src, tgt), pts in sorted(link_agg.items(), key=lambda x: -x[1])
|
||
]
|
||
|
||
return {"nodes": nodes, "links": links}
|
||
except sqlite3.Error:
|
||
return {"nodes": [], "links": []}
|
||
finally:
|
||
conn.close()
|
||
|
||
|
||
def _ordinal(n: int) -> str:
|
||
"""Return '1st', '2nd', '3rd', '4th', etc."""
|
||
if 11 <= n % 100 <= 13:
|
||
return f"{n}th"
|
||
suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
|
||
return f"{n}{suffix}"
|
||
|
||
|
||
def get_dosing_intervals(
|
||
db_path: Path,
|
||
date_filter_id: str,
|
||
chart_type: str,
|
||
drug: Optional[str] = None,
|
||
trust: Optional[str] = None,
|
||
) -> list[dict]:
|
||
"""Level 3 drug nodes with parsed dosing interval data.
|
||
|
||
Returns list of dicts:
|
||
[{drug, trust_name, directory, weekly_interval, dose_count, total_weeks, patients}]
|
||
"""
|
||
from data_processing.parsing import parse_average_spacing
|
||
|
||
conn = sqlite3.connect(str(db_path))
|
||
conn.row_factory = sqlite3.Row
|
||
try:
|
||
where = ["date_filter_id = ?", "chart_type = ?", "level = 3",
|
||
"average_spacing IS NOT NULL", "average_spacing != ''"]
|
||
params: list = [date_filter_id, chart_type]
|
||
|
||
if drug:
|
||
where.append("labels = ?")
|
||
params.append(drug)
|
||
if trust:
|
||
where.append("trust_name = ?")
|
||
params.append(trust)
|
||
|
||
query = f"""
|
||
SELECT labels AS drug, trust_name, directory, value AS patients,
|
||
average_spacing
|
||
FROM pathway_nodes
|
||
WHERE {' AND '.join(where)}
|
||
ORDER BY labels, trust_name
|
||
"""
|
||
rows = conn.execute(query, params).fetchall()
|
||
|
||
result = []
|
||
for r in rows:
|
||
parsed = parse_average_spacing(r["average_spacing"])
|
||
for entry in parsed:
|
||
result.append({
|
||
"drug": entry["drug_name"],
|
||
"trust_name": r["trust_name"] or "",
|
||
"directory": r["directory"] or "",
|
||
"weekly_interval": entry["weekly_interval"],
|
||
"dose_count": entry["dose_count"],
|
||
"total_weeks": entry["total_weeks"],
|
||
"patients": r["patients"] or 0,
|
||
})
|
||
|
||
return result
|
||
except sqlite3.Error:
|
||
return []
|
||
finally:
|
||
conn.close()
|
||
|
||
|
||
def get_drug_directory_matrix(
|
||
db_path: Path,
|
||
date_filter_id: str,
|
||
chart_type: str,
|
||
trust: Optional[str] = None,
|
||
) -> dict:
|
||
"""Level 3 nodes pivoted as directory × drug matrix.
|
||
|
||
Returns dict with:
|
||
directories: sorted list of directory names (rows)
|
||
drugs: sorted list of drug names (columns)
|
||
matrix: {directory: {drug: {patients, cost, cost_pp_pa}}}
|
||
"""
|
||
conn = sqlite3.connect(str(db_path))
|
||
conn.row_factory = sqlite3.Row
|
||
try:
|
||
where = ["date_filter_id = ?", "chart_type = ?", "level = 3"]
|
||
params: list = [date_filter_id, chart_type]
|
||
|
||
if trust:
|
||
where.append("trust_name = ?")
|
||
params.append(trust)
|
||
|
||
query = f"""
|
||
SELECT labels AS drug, directory, value AS patients, cost, cost_pp_pa
|
||
FROM pathway_nodes
|
||
WHERE {' AND '.join(where)}
|
||
ORDER BY directory, labels
|
||
"""
|
||
rows = conn.execute(query, params).fetchall()
|
||
|
||
# Aggregate across trusts
|
||
matrix = {}
|
||
dir_totals = {}
|
||
drug_totals = {}
|
||
|
||
for r in rows:
|
||
d = r["directory"] or ""
|
||
drug = r["drug"] or ""
|
||
patients = r["patients"] or 0
|
||
cost = float(r["cost"]) if r["cost"] else 0.0
|
||
cpp = _safe_float(r["cost_pp_pa"])
|
||
|
||
if d not in matrix:
|
||
matrix[d] = {}
|
||
if drug not in matrix[d]:
|
||
matrix[d][drug] = {"patients": 0, "cost": 0.0}
|
||
|
||
matrix[d][drug]["patients"] += patients
|
||
matrix[d][drug]["cost"] += cost
|
||
|
||
dir_totals[d] = dir_totals.get(d, 0) + patients
|
||
drug_totals[drug] = drug_totals.get(drug, 0) + patients
|
||
|
||
# Add cost_pp_pa to each cell
|
||
for d in matrix:
|
||
for drug in matrix[d]:
|
||
cell = matrix[d][drug]
|
||
cell["cost"] = round(cell["cost"], 2)
|
||
cell["cost_pp_pa"] = (
|
||
round(cell["cost"] / cell["patients"], 2) if cell["patients"] else 0
|
||
)
|
||
|
||
# Sort directories by total patients desc, drugs by frequency desc
|
||
directories = sorted(dir_totals, key=lambda x: -dir_totals[x])
|
||
drugs = sorted(drug_totals, key=lambda x: -drug_totals[x])
|
||
|
||
return {"directories": directories, "drugs": drugs, "matrix": matrix}
|
||
except sqlite3.Error:
|
||
return {"directories": [], "drugs": [], "matrix": {}}
|
||
finally:
|
||
conn.close()
|
||
|
||
|
||
def get_treatment_durations(
|
||
db_path: Path,
|
||
date_filter_id: str,
|
||
chart_type: str,
|
||
directory: Optional[str] = None,
|
||
trust: Optional[str] = None,
|
||
) -> list[dict]:
|
||
"""Level 3 drug nodes with average treatment duration.
|
||
|
||
Returns list of dicts: [{drug, avg_days, patients, directory}]
|
||
Sorted by avg_days desc. Excludes nodes with no avg_days data.
|
||
"""
|
||
conn = sqlite3.connect(str(db_path))
|
||
conn.row_factory = sqlite3.Row
|
||
try:
|
||
where = ["date_filter_id = ?", "chart_type = ?", "level = 3",
|
||
"avg_days IS NOT NULL"]
|
||
params: list = [date_filter_id, chart_type]
|
||
|
||
if directory:
|
||
where.append("directory = ?")
|
||
params.append(directory)
|
||
if trust:
|
||
where.append("trust_name = ?")
|
||
params.append(trust)
|
||
|
||
query = f"""
|
||
SELECT labels AS drug, directory, trust_name,
|
||
value AS patients, avg_days
|
||
FROM pathway_nodes
|
||
WHERE {' AND '.join(where)}
|
||
ORDER BY avg_days DESC
|
||
"""
|
||
rows = conn.execute(query, params).fetchall()
|
||
|
||
# Aggregate across trusts: weighted average of avg_days by patients
|
||
agg = {}
|
||
for r in rows:
|
||
key = (r["directory"] or "", r["drug"])
|
||
patients = r["patients"] or 0
|
||
days = _safe_float(r["avg_days"])
|
||
if patients == 0 or days == 0:
|
||
continue
|
||
|
||
if key not in agg:
|
||
agg[key] = {"drug": r["drug"], "directory": r["directory"] or "",
|
||
"total_weighted_days": 0.0, "total_patients": 0}
|
||
agg[key]["total_weighted_days"] += days * patients
|
||
agg[key]["total_patients"] += patients
|
||
|
||
result = []
|
||
for v in agg.values():
|
||
if v["total_patients"] > 0:
|
||
result.append({
|
||
"drug": v["drug"],
|
||
"directory": v["directory"],
|
||
"avg_days": round(v["total_weighted_days"] / v["total_patients"], 1),
|
||
"patients": v["total_patients"],
|
||
})
|
||
|
||
result.sort(key=lambda x: -x["avg_days"])
|
||
return result
|
||
except sqlite3.Error:
|
||
return []
|
||
finally:
|
||
conn.close()
|