Skip to content

Commit c0cf560

Browse files
committed
dedup connections
1 parent 5912e9f commit c0cf560

File tree

2 files changed

+38
-27
lines changed

2 files changed

+38
-27
lines changed

src/exo/shared/topology.py

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import contextlib
22
from collections.abc import Sequence
3+
from dataclasses import dataclass
34
from typing import Iterable
45

56
import rustworkx as rx
@@ -16,12 +17,11 @@ class TopologySnapshot(BaseModel):
1617
model_config = ConfigDict(frozen=True, extra="forbid")
1718

1819

20+
@dataclass
1921
class Topology:
20-
def __init__(self) -> None:
21-
self._graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = (
22-
rx.PyDiGraph()
23-
)
24-
self._node_id_to_rx_id_map: dict[NodeId, int] = dict()
22+
# the _graph can be used as a int -> NodeId map.
23+
_graph = rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection]()
24+
_vertex_indices = dict[NodeId, int]()
2525

2626
def to_snapshot(self) -> TopologySnapshot:
2727
return TopologySnapshot(
@@ -43,60 +43,68 @@ def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology":
4343
return topology
4444

4545
def add_node(self, node: NodeId) -> None:
46-
if node in self._node_id_to_rx_id_map:
46+
if node in self._vertex_indices:
4747
return
4848
rx_id = self._graph.add_node(node)
49-
self._node_id_to_rx_id_map[node] = rx_id
49+
self._vertex_indices[node] = rx_id
5050
self._graph[rx_id] = node
5151

5252
def node_is_leaf(self, node_id: NodeId) -> bool:
5353
return (
54-
node_id in self._node_id_to_rx_id_map
55-
and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) <= 1
54+
node_id in self._vertex_indices
55+
and len(self._graph.neighbors(self._vertex_indices[node_id])) <= 1
5656
)
5757

5858
def neighbours(self, node_id: NodeId) -> list[NodeId]:
5959
return [
6060
self._graph[rx_id]
61-
for rx_id in self._graph.neighbors(self._node_id_to_rx_id_map[node_id])
61+
for rx_id in self._graph.neighbors(self._vertex_indices[node_id])
6262
]
6363

6464
def out_edges(
6565
self, node_id: NodeId
6666
) -> Iterable[tuple[NodeId, SocketConnection | RDMAConnection]]:
67-
if node_id not in self._node_id_to_rx_id_map:
67+
if node_id not in self._vertex_indices:
6868
return []
6969
return (
7070
(self._graph[nid], conn)
7171
for _, nid, conn in self._graph.out_edges(
72-
self._node_id_to_rx_id_map[node_id]
72+
self._vertex_indices[node_id]
7373
)
7474
)
7575

7676
def contains_node(self, node_id: NodeId) -> bool:
77-
return node_id in self._node_id_to_rx_id_map
77+
return node_id in self._vertex_indices
7878

7979
def add_connection(
8080
self,
8181
source: NodeId,
8282
sink: NodeId,
8383
connection: SocketConnection | RDMAConnection,
8484
) -> None:
85-
if source not in self._node_id_to_rx_id_map:
85+
if connection in self.get_all_connections_between(source, sink):
86+
return
87+
88+
if source not in self._vertex_indices:
8689
self.add_node(source)
87-
if sink not in self._node_id_to_rx_id_map:
90+
if sink not in self._vertex_indices:
8891
self.add_node(sink)
8992

90-
src_id = self._node_id_to_rx_id_map[source]
91-
sink_id = self._node_id_to_rx_id_map[sink]
93+
src_id = self._vertex_indices[source]
94+
sink_id = self._vertex_indices[sink]
9295

93-
self._graph.add_edge(src_id, sink_id, connection)
96+
_ = self._graph.add_edge(src_id, sink_id, connection)
9497

9598
def get_all_connections_between(
9699
self, source: NodeId, sink: NodeId
97100
) -> Iterable[SocketConnection | RDMAConnection]:
98-
src_id = self._node_id_to_rx_id_map[source]
99-
sink_id = self._node_id_to_rx_id_map[sink]
101+
if source not in self._vertex_indices:
102+
return []
103+
if sink not in self._vertex_indices:
104+
return []
105+
106+
src_id = self._vertex_indices[source]
107+
sink_id = self._vertex_indices[sink]
100108
try:
101109
return self._graph.get_all_edge_data(src_id, sink_id)
102110
except rx.NoEdgeBetweenNodes:
@@ -118,19 +126,19 @@ def list_connections(
118126
)
119127

120128
def remove_node(self, node_id: NodeId) -> None:
121-
if node_id not in self._node_id_to_rx_id_map:
129+
if node_id not in self._vertex_indices:
122130
return
123131

124-
rx_idx = self._node_id_to_rx_id_map[node_id]
132+
rx_idx = self._vertex_indices[node_id]
125133
self._graph.remove_node(rx_idx)
126134

127-
del self._node_id_to_rx_id_map[node_id]
135+
del self._vertex_indices[node_id]
128136

129137
def replace_all_out_tb_connections(
130138
self, source: NodeId, new_connections: Sequence[tuple[NodeId, RDMAConnection]]
131139
) -> None:
132140
for conn_idx in self._graph.out_edge_indices(
133-
self._node_id_to_rx_id_map[source]
141+
self._vertex_indices[source]
134142
):
135143
if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection):
136144
self._graph.remove_edge_from_index(conn_idx)
@@ -141,7 +149,7 @@ def remove_connection(
141149
self, source: NodeId, sink: NodeId, edge: SocketConnection | RDMAConnection
142150
) -> None:
143151
for conn_idx in self._graph.edge_indices_from_endpoints(
144-
self._node_id_to_rx_id_map[source], self._node_id_to_rx_id_map[sink]
152+
self._vertex_indices[source], self._vertex_indices[sink]
145153
):
146154
if self._graph.get_edge_data_by_index(conn_idx) == edge:
147155
self._graph.remove_edge_from_index(conn_idx)
@@ -179,7 +187,7 @@ def get_cycles_tb(self) -> list[list[NodeId]]:
179187

180188
def get_subgraph_from_nodes(self, nodes: list[NodeId]) -> "Topology":
181189
node_idxs = [node for node in nodes]
182-
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
190+
rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
183191
topology = Topology()
184192
for rx_idx in rx_idxs:
185193
topology.add_node(self._graph[rx_idx])
@@ -190,7 +198,7 @@ def get_subgraph_from_nodes(self, nodes: list[NodeId]) -> "Topology":
190198

191199
def is_thunderbolt_cycle(self, cycle: list[NodeId]) -> bool:
192200
node_idxs = [node for node in cycle]
193-
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
201+
rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
194202
for rid in rx_idxs:
195203
for neighbor_rid in self._graph.neighbors(rid):
196204
if neighbor_rid not in rx_idxs:

src/exo/shared/types/topology.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,8 @@ class LinkType(str, Enum):
2525
class SocketConnection(FrozenModel):
2626
sink_multiaddr: Multiaddr
2727

28+
def __hash__(self):
29+
return hash(self.sink_multiaddr.ip_address)
30+
2831
def is_thunderbolt(self) -> bool:
2932
return str(self.sink_multiaddr.ipv4_address).startswith("169.254")

0 commit comments

Comments
 (0)