"""Annotated interactive graph visualization for :class:`~monee.model.Network`.
Entry point: :func:`plot_network`.
"""
import networkx as nx
import plotly.graph_objects as go
import monee.model as mm # kept for Network type hint only
from monee.model.core import Intermediate, IntermediateEq, PostProcess, Var
# Shared theme constants and layout helpers come from result_visualization, so
# network plots and result plots always look identical.
from monee.visualization.result_visualization import (
_ACCENT,
_BG,
_BORDER,
_DIM_COLOR,
_FONT,
_FONT_COLOR,
_GRID_LABEL,
_GRID_SYMBOL,
_PANEL,
_TL_GRAY,
_compute_layout,
_fmt,
_grid_type,
_sep,
)
# Attribute names that are never listed in hover tooltips.
_SKIP_ATTRS: frozenset[str] = frozenset(("active", "independent", "ignored"))

# Solver-side (Pyomo-like) objects attached to model instances. They carry no
# useful design-time information for the user, so _model_params drops them and
# they never reach the hover text.
_SOLVER_TYPES = (Var, Intermediate, IntermediateEq, PostProcess)
# Marker scaling
def _adaptive_marker_px(graph: nx.Graph, pos: dict, nominal_px: int = 1000) -> float:
"""Marker size (px) scaled to the layout density so symbols don't swallow
their neighbours. Uses the lower-quartile edge length relative to the
layout extent (different layout engines produce wildly different absolute
coordinates, so fixed pixel sizes only suit one of them)."""
if len(pos) < 2:
return 24.0
xs = [p[0] for p in pos.values()]
ys = [p[1] for p in pos.values()]
extent = max(max(xs) - min(xs), max(ys) - min(ys))
if extent <= 0:
return 24.0
edge_lengths = sorted(
((pos[u][0] - pos[v][0]) ** 2 + (pos[u][1] - pos[v][1]) ** 2) ** 0.5
for u, v in graph.edges()
)
if not edge_lengths:
return 24.0
# 10th-percentile edge length: sizes markers to survive the densest
# region (e.g. a meshed electrical cluster) without collapsing everywhere
# else; the true minimum may be a near-coincident node pair, so avoid it.
short_edge = edge_lengths[len(edge_lengths) // 10]
# Markers at most ~65% of a short edge, in pixels of the nominal plot size.
return max(8.0, min(20.0, 0.65 * short_edge / extent * nominal_px))
# Model parameter extraction
def _model_params(model) -> dict:
"""Extract scalar constructor parameters from a model instance.
Skips private attributes and solver-variable objects (Var, Intermediate).
Returns only plain Python scalars (int, float, bool, str).
"""
params: dict = {}
try:
attrs = vars(model)
except TypeError:
return params
for name, val in attrs.items():
if name.startswith("_"):
continue
if isinstance(val, _SOLVER_TYPES):
continue
if callable(val):
continue
if isinstance(val, (int, float, bool, str)):
params[name] = val
return params
# Inline labels shown directly on the graph
def _node_label(int_node) -> str:
"""Short text placed above a node marker."""
model = int_node.model
bkv = getattr(model, "base_kv", None)
if bkv is not None:
try:
return f"{float(bkv):.4g} kV"
except (TypeError, ValueError):
pass
return ""
def _branch_label(int_branch) -> str:
"""Short text placed at a branch midpoint."""
model = int_branch.model
parts: list[str] = []
d = getattr(model, "diameter_m", None)
if d is not None:
try:
mm_val = float(d) * 1000
parts.append(f"⌀{mm_val:.0f}mm")
except (TypeError, ValueError):
pass
length = getattr(model, "length_m", None)
if length is not None:
try:
lm = float(length)
parts.append(f"{lm / 1000:.4g}km" if lm >= 1000 else f"{lm:.4g}m")
except (TypeError, ValueError):
pass
if not parts:
# PowerLine fallback: total resistance
r = getattr(model, "r_ohm_per_m", None)
if r is not None and length is not None:
try:
parts.append(f"{float(r) * float(length):.3g} Ω")
except (TypeError, ValueError):
pass
return " ".join(parts)
# Rich hover text
def _node_hover(int_node, children: list) -> str:
    """Build the HTML hover tooltip for a network node.

    The tooltip contains, in order:

    * a header with the node name (if any), the model type, and the node id;
    * one line per scalar model parameter not listed in ``_SKIP_ATTRS``;
    * when *children* is non-empty, a "children" separator followed by one
      line per child with its type and parameters.

    Args:
        int_node: Internal node wrapper exposing ``model`` and, optionally,
            ``id`` and ``name``.
        children: Child wrappers exposing ``model``. May be empty.

    Returns:
        The tooltip lines joined with ``<br>``.
    """
    model = int_node.model
    type_name = type(model).__name__
    node_id = getattr(int_node, "id", "?")
    node_name = getattr(int_node, "name", None)

    if not node_name:
        header = f"<b>{type_name} #{node_id}</b>"
    else:
        header = (
            f"<b>{node_name}</b> "
            f"<span style='color:{_DIM_COLOR}'>{type_name} #{node_id}</span>"
        )

    out = [header, _sep()]
    out.extend(
        f"<span style='color:{_DIM_COLOR}'>{key}</span> <b>{_fmt(value)}</b>"
        for key, value in _model_params(model).items()
        if key not in _SKIP_ATTRS
    )

    if children:
        out.append("<br>" + _sep("children"))
        for child in children:
            child_type = type(child.model).__name__
            child_values = " ".join(
                f"<span style='color:{_DIM_COLOR}'>{key}</span> {_fmt(value)}"
                for key, value in _model_params(child.model).items()
                if key not in _SKIP_ATTRS
            )
            out.append(f"<i>[{child_type}]</i> {child_values}")

    return "<br>".join(out)
def _branch_hover(int_branch, from_id, to_id) -> str:
    """Build the HTML hover tooltip for a network branch.

    The tooltip contains, in order:

    * a header with the branch name (if any) and the model type;
    * the ``from_id → to_id`` endpoint pair;
    * one line per scalar model parameter not listed in ``_SKIP_ATTRS``.

    Returns:
        The tooltip lines joined with ``<br>``.
    """
    model = int_branch.model
    type_name = type(model).__name__
    branch_name = getattr(int_branch, "name", None)

    if not branch_name:
        header = f"<b>{type_name}</b>"
    else:
        header = (
            f"<b>{branch_name}</b> "
            f"<span style='color:{_DIM_COLOR}'>{type_name}</span>"
        )

    out = [
        header,
        f"<span style='color:{_DIM_COLOR}'>{from_id} → {to_id}</span>",
        _sep(),
    ]
    out.extend(
        f"<span style='color:{_DIM_COLOR}'>{key}</span> <b>{_fmt(value)}</b>"
        for key, value in _model_params(model).items()
        if key not in _SKIP_ATTRS
    )
    return "<br>".join(out)
# Main function
def _collect_node_data(
    graph: nx.Graph,
    network: mm.Network,
    pos: dict,
    show_children: bool,
) -> dict[str, dict]:
    """Group node coordinates, hover HTML and inline labels by grid type.

    Returns:
        A mapping ``grid type -> {"x", "y", "hover", "labels"}`` for the grid
        types ``power``, ``water``, ``gas`` and ``cp``. Each of the four inner
        lists holds one entry per node of that type.
    """
    grid_data: dict[str, dict] = {
        g: {"x": [], "y": [], "hover": [], "labels": []}
        for g in ("power", "water", "gas", "cp")
    }
    for node_id in graph.nodes:
        int_node = graph.nodes[node_id]["internal_node"]
        # Nodes that are not independent are drawn as coupling points.
        gtype = "cp" if not int_node.independent else _grid_type(int_node.grid)
        x, y = pos[node_id]
        children = (
            network.childs_by_ids(int_node.child_ids) if show_children else []
        )
        d = grid_data[gtype]
        d["x"].append(x)
        d["y"].append(y)
        d["hover"].append(_node_hover(int_node, children))
        d["labels"].append(_node_label(int_node))
    return grid_data


def _node_traces(grid_data: dict[str, dict], marker_px: float) -> tuple[list, list]:
    """Build the glow and marker traces for every non-empty grid type.

    Returns:
        ``(glow_traces, marker_traces)``. The glow traces are larger, nearly
        transparent, hover-less copies of the markers.
    """
    glow_traces: list = []
    marker_traces: list = []
    for gtype, d in grid_data.items():
        if not d["x"]:
            continue
        color = _ACCENT[gtype]
        glow_traces.append(
            go.Scatter(
                x=d["x"],
                y=d["y"],
                mode="markers",
                hoverinfo="skip",
                showlegend=False,
                marker={
                    "symbol": _GRID_SYMBOL[gtype],
                    "size": 1.5 * marker_px,
                    "color": color,
                    "opacity": 0.10,
                    "line": {"width": 0},
                },
            )
        )
        marker_traces.append(
            go.Scatter(
                x=d["x"],
                y=d["y"],
                mode="markers+text",
                textposition="top center",
                text=d["labels"],
                textfont={"family": _FONT, "size": 11, "color": _DIM_COLOR},
                hovertext=d["hover"],
                hoverinfo="text",
                name=_GRID_LABEL[gtype],
                # The curated legend traces carry the legend; without this
                # the legend would list every grid type twice.
                showlegend=False,
                marker={
                    "symbol": _GRID_SYMBOL[gtype],
                    "size": marker_px,
                    "color": color,
                    "opacity": 0.75,
                    "line": {"width": max(1.0, marker_px / 12), "color": color},
                },
            )
        )
    return glow_traces, marker_traces


def _branch_traces(
    graph: nx.Graph, pos: dict, marker_px: float
) -> tuple[list, go.Scatter]:
    """Build the branch line traces and the shared midpoint trace.

    Segments are grouped by ``(color, is_cp)`` so that each group renders as
    one ``None``-separated line trace. A single midpoint trace carries the
    per-branch hover text and inline labels.

    Returns:
        ``(edge_traces, midpoint_trace)``.
    """
    color_groups: dict[tuple, list] = {}
    mid_x: list[float] = []
    mid_y: list[float] = []
    mid_hover: list[str] = []
    mid_label: list[str] = []
    mid_colors: list[str] = []
    # edges(keys=True) requires a networkx multigraph.
    for from_node, to_node, key in graph.edges(keys=True):
        int_branch = graph.edges[from_node, to_node, key]["internal_branch"]
        is_cp = int_branch.model.is_cp()
        if is_cp:
            color = _ACCENT["cp"]
        else:
            # Other branches take the color of their source node's grid.
            int_node_from = graph.nodes[from_node]["internal_node"]
            gtype = (
                _grid_type(int_node_from.grid) if int_node_from.independent else "cp"
            )
            color = _ACCENT.get(gtype, _TL_GRAY)
        x0, y0 = pos[from_node]
        x1, y1 = pos[to_node]
        color_groups.setdefault((color, is_cp), []).append((x0, y0, x1, y1))
        mid_x.append((x0 + x1) / 2)
        mid_y.append((y0 + y1) / 2)
        mid_hover.append(_branch_hover(int_branch, from_node, to_node))
        mid_label.append(_branch_label(int_branch))
        mid_colors.append(color)

    edge_traces: list = []
    for (color, is_cp), segs in color_groups.items():
        x_pts: list = []
        y_pts: list = []
        for x0, y0, x1, y1 in segs:
            # A None point breaks the line, so each branch is its own segment.
            x_pts += [x0, x1, None]
            y_pts += [y0, y1, None]
        edge_traces.append(
            go.Scatter(
                x=x_pts,
                y=y_pts,
                mode="lines",
                hoverinfo="none",
                showlegend=False,
                line={
                    "color": color,
                    "width": 3.5 if not is_cp else 2,
                    "dash": "dot" if is_cp else "solid",
                },
                opacity=0.55,
            )
        )

    midpoint_trace = go.Scatter(
        x=mid_x,
        y=mid_y,
        mode="markers+text",
        text=mid_label,
        textposition="middle right",
        textfont={"family": _FONT, "size": 10, "color": _DIM_COLOR},
        hovertext=mid_hover,
        hoverinfo="text",
        showlegend=False,
        marker={
            "size": max(4.0, 0.35 * marker_px),
            "color": mid_colors,
            "symbol": "circle",
            "opacity": 0.85,
            "line": {"width": 1.5, "color": _BG},
        },
    )
    return edge_traces, midpoint_trace


def _legend_traces() -> list:
    """Build data-less placeholder traces that populate the legend.

    Returns one marker entry per grid type in ``_GRID_LABEL``, plus a dotted
    line entry for coupling branches.
    """
    entries: list = [
        go.Scatter(
            x=[None],
            y=[None],
            mode="markers",
            marker={
                "size": 11,
                "color": _ACCENT[gtype],
                "symbol": _GRID_SYMBOL[gtype],
                "line": {"width": 2, "color": _ACCENT[gtype]},
            },
            name=label,
        )
        for gtype, label in _GRID_LABEL.items()
    ]
    entries.append(
        go.Scatter(
            x=[None],
            y=[None],
            mode="lines",
            line={"color": _ACCENT["cp"], "width": 2, "dash": "dot"},
            name="Coupling branch (CP)",
        )
    )
    return entries


def _network_layout(title: str | None) -> go.Layout:
    """Return the themed figure layout used by :func:`plot_network`.

    The layout has hidden axes with equal x/y scaling, and the legend is
    placed in the right margin.
    """
    return go.Layout(
        title={
            "text": title or "Network",
            "font": {"family": _FONT, "size": 18, "color": _FONT_COLOR},
            "x": 0.5,
            "xanchor": "center",
            "y": 0.97,
        },
        paper_bgcolor=_BG,
        plot_bgcolor=_BG,
        hovermode="closest",
        hoverlabel={
            "bgcolor": _PANEL,
            "bordercolor": _BORDER,
            "font": {"family": _FONT, "size": 12, "color": _FONT_COLOR},
            "namelength": -1,
        },
        xaxis={
            "showgrid": False,
            "zeroline": False,
            "showticklabels": False,
            "showline": False,
        },
        yaxis={
            "showgrid": False,
            "zeroline": False,
            "showticklabels": False,
            "showline": False,
            # Equal x/y scale so the layout geometry is not distorted.
            "scaleanchor": "x",
        },
        font={"family": _FONT, "color": _FONT_COLOR},
        autosize=True,
        margin={"l": 30, "r": 200, "t": 60, "b": 30},
        legend={
            "title": {
                "text": "Legend",
                "font": {"family": _FONT, "size": 12, "color": _DIM_COLOR},
            },
            "x": 1.02,
            "y": 1.0,
            "xanchor": "left",
            "yanchor": "top",
            "bgcolor": "rgba(246, 248, 250, 0.95)",
            "bordercolor": _BORDER,
            "borderwidth": 1,
            "font": {"family": _FONT, "size": 11, "color": _FONT_COLOR},
            "itemsizing": "constant",
            "tracegroupgap": 6,
        },
    )


def plot_network(
    network: mm.Network,
    title: str | None = None,
    show_children: bool = True,
    use_monee_positions: bool = False,
    write_to: str | None = None,
) -> go.Figure:
    """Plot a :class:`~monee.model.Network` as an annotated interactive graph.

    Nodes are colored by energy-carrier type (electricity, heat/water, gas,
    coupling point). Branch midpoints show compact parameter labels. Hovering
    over any element reveals the full set of scalar model parameters.

    Args:
        network: The :class:`~monee.model.Network` to visualize. Its internal
            graph (``network._network_internal``) must support
            ``edges(keys=True)``, i.e. be a networkx multigraph.
        title: Figure title. Defaults to ``"Network"``, which is also used
            when an empty string is passed.
        show_children: Show attached child components (loads, generators, …)
            in the parent node's hover tooltip.
        use_monee_positions: Use stored ``node.position`` coordinates instead
            of the automatic graph layout.
        write_to: Optional path to export the figure (PDF / PNG / SVG) at a
            fixed 1230 x 1000 px size.

    Returns:
        A :class:`plotly.graph_objects.Figure`.
    """
    graph: nx.Graph = network._network_internal
    pos = _compute_layout(graph, network, use_monee_positions)
    marker_px = _adaptive_marker_px(graph, pos)

    grid_data = _collect_node_data(graph, network, pos, show_children)
    glow_traces, marker_traces = _node_traces(grid_data, marker_px)
    edge_traces, midpoint_trace = _branch_traces(graph, pos, marker_px)

    # Render order (later traces are drawn on top):
    # edges → midpoints → glow → markers → legend placeholders.
    all_traces = (
        edge_traces
        + [midpoint_trace]
        + glow_traces
        + marker_traces
        + _legend_traces()
    )
    fig = go.Figure(data=all_traces, layout=_network_layout(title))

    if write_to is not None:
        # Fixed export size matching the nominal-pixel basis of the adaptive
        # marker scaling (the interactive autosize stays untouched).
        fig.write_image(write_to, width=1230, height=1000)
    return fig