11import contextlib
22from collections .abc import Sequence
3+ from dataclasses import dataclass
34from typing import Iterable
45
56import rustworkx as rx
@@ -16,12 +17,11 @@ class TopologySnapshot(BaseModel):
1617 model_config = ConfigDict (frozen = True , extra = "forbid" )
1718
1819
20+ @dataclass
1921class Topology :
20- def __init__ (self ) -> None :
21- self ._graph : rx .PyDiGraph [NodeId , SocketConnection | RDMAConnection ] = (
22- rx .PyDiGraph ()
23- )
24- self ._node_id_to_rx_id_map : dict [NodeId , int ] = dict ()
22+ # the _graph can be used as a int -> NodeId map.
23+ _graph = rx .PyDiGraph [NodeId , SocketConnection | RDMAConnection ]()
24+ _vertex_indices = dict [NodeId , int ]()
2525
2626 def to_snapshot (self ) -> TopologySnapshot :
2727 return TopologySnapshot (
@@ -43,60 +43,68 @@ def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology":
4343 return topology
4444
4545 def add_node (self , node : NodeId ) -> None :
46- if node in self ._node_id_to_rx_id_map :
46+ if node in self ._vertex_indices :
4747 return
4848 rx_id = self ._graph .add_node (node )
49- self ._node_id_to_rx_id_map [node ] = rx_id
49+ self ._vertex_indices [node ] = rx_id
5050 self ._graph [rx_id ] = node
5151
5252 def node_is_leaf (self , node_id : NodeId ) -> bool :
5353 return (
54- node_id in self ._node_id_to_rx_id_map
55- and len (self ._graph .neighbors (self ._node_id_to_rx_id_map [node_id ])) <= 1
54+ node_id in self ._vertex_indices
55+ and len (self ._graph .neighbors (self ._vertex_indices [node_id ])) <= 1
5656 )
5757
5858 def neighbours (self , node_id : NodeId ) -> list [NodeId ]:
5959 return [
6060 self ._graph [rx_id ]
61- for rx_id in self ._graph .neighbors (self ._node_id_to_rx_id_map [node_id ])
61+ for rx_id in self ._graph .neighbors (self ._vertex_indices [node_id ])
6262 ]
6363
6464 def out_edges (
6565 self , node_id : NodeId
6666 ) -> Iterable [tuple [NodeId , SocketConnection | RDMAConnection ]]:
67- if node_id not in self ._node_id_to_rx_id_map :
67+ if node_id not in self ._vertex_indices :
6868 return []
6969 return (
7070 (self ._graph [nid ], conn )
7171 for _ , nid , conn in self ._graph .out_edges (
72- self ._node_id_to_rx_id_map [node_id ]
72+ self ._vertex_indices [node_id ]
7373 )
7474 )
7575
7676 def contains_node (self , node_id : NodeId ) -> bool :
77- return node_id in self ._node_id_to_rx_id_map
77+ return node_id in self ._vertex_indices
7878
7979 def add_connection (
8080 self ,
8181 source : NodeId ,
8282 sink : NodeId ,
8383 connection : SocketConnection | RDMAConnection ,
8484 ) -> None :
85- if source not in self ._node_id_to_rx_id_map :
85+ if connection in self .get_all_connections_between (source , sink ):
86+ return
87+
88+ if source not in self ._vertex_indices :
8689 self .add_node (source )
87- if sink not in self ._node_id_to_rx_id_map :
90+ if sink not in self ._vertex_indices :
8891 self .add_node (sink )
8992
90- src_id = self ._node_id_to_rx_id_map [source ]
91- sink_id = self ._node_id_to_rx_id_map [sink ]
93+ src_id = self ._vertex_indices [source ]
94+ sink_id = self ._vertex_indices [sink ]
9295
93- self ._graph .add_edge (src_id , sink_id , connection )
96+ _ = self ._graph .add_edge (src_id , sink_id , connection )
9497
9598 def get_all_connections_between (
9699 self , source : NodeId , sink : NodeId
97100 ) -> Iterable [SocketConnection | RDMAConnection ]:
98- src_id = self ._node_id_to_rx_id_map [source ]
99- sink_id = self ._node_id_to_rx_id_map [sink ]
101+ if source not in self ._vertex_indices :
102+ return []
103+ if sink not in self ._vertex_indices :
104+ return []
105+
106+ src_id = self ._vertex_indices [source ]
107+ sink_id = self ._vertex_indices [sink ]
100108 try :
101109 return self ._graph .get_all_edge_data (src_id , sink_id )
102110 except rx .NoEdgeBetweenNodes :
@@ -118,19 +126,19 @@ def list_connections(
118126 )
119127
120128 def remove_node (self , node_id : NodeId ) -> None :
121- if node_id not in self ._node_id_to_rx_id_map :
129+ if node_id not in self ._vertex_indices :
122130 return
123131
124- rx_idx = self ._node_id_to_rx_id_map [node_id ]
132+ rx_idx = self ._vertex_indices [node_id ]
125133 self ._graph .remove_node (rx_idx )
126134
127- del self ._node_id_to_rx_id_map [node_id ]
135+ del self ._vertex_indices [node_id ]
128136
129137 def replace_all_out_tb_connections (
130138 self , source : NodeId , new_connections : Sequence [tuple [NodeId , RDMAConnection ]]
131139 ) -> None :
132140 for conn_idx in self ._graph .out_edge_indices (
133- self ._node_id_to_rx_id_map [source ]
141+ self ._vertex_indices [source ]
134142 ):
135143 if isinstance (self ._graph .get_edge_data_by_index (conn_idx ), RDMAConnection ):
136144 self ._graph .remove_edge_from_index (conn_idx )
@@ -141,7 +149,7 @@ def remove_connection(
141149 self , source : NodeId , sink : NodeId , edge : SocketConnection | RDMAConnection
142150 ) -> None :
143151 for conn_idx in self ._graph .edge_indices_from_endpoints (
144- self ._node_id_to_rx_id_map [source ], self ._node_id_to_rx_id_map [sink ]
152+ self ._vertex_indices [source ], self ._vertex_indices [sink ]
145153 ):
146154 if self ._graph .get_edge_data_by_index (conn_idx ) == edge :
147155 self ._graph .remove_edge_from_index (conn_idx )
@@ -179,7 +187,7 @@ def get_cycles_tb(self) -> list[list[NodeId]]:
179187
180188 def get_subgraph_from_nodes (self , nodes : list [NodeId ]) -> "Topology" :
181189 node_idxs = [node for node in nodes ]
182- rx_idxs = [self ._node_id_to_rx_id_map [idx ] for idx in node_idxs ]
190+ rx_idxs = [self ._vertex_indices [idx ] for idx in node_idxs ]
183191 topology = Topology ()
184192 for rx_idx in rx_idxs :
185193 topology .add_node (self ._graph [rx_idx ])
@@ -190,7 +198,7 @@ def get_subgraph_from_nodes(self, nodes: list[NodeId]) -> "Topology":
190198
191199 def is_thunderbolt_cycle (self , cycle : list [NodeId ]) -> bool :
192200 node_idxs = [node for node in cycle ]
193- rx_idxs = [self ._node_id_to_rx_id_map [idx ] for idx in node_idxs ]
201+ rx_idxs = [self ._vertex_indices [idx ] for idx in node_idxs ]
194202 for rid in rx_idxs :
195203 for neighbor_rid in self ._graph .neighbors (rid ):
196204 if neighbor_rid not in rx_idxs :
0 commit comments