Let n be the number of vertices and m the number of edges. A heap-optimized version of Prim’s algorithm exists (O(m log n) with a binary heap), but in most cases the standard O(n^2) Prim’s algorithm and Kruskal’s algorithm are sufficient.
- For dense graphs (m close to n^2), Prim’s algorithm is generally more efficient.
- For sparse graphs (m close to n), Kruskal’s algorithm is the better choice.
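The heap-optimized variant of Prim’s algorithm mentioned above can be sketched roughly as follows. This is a minimal illustration, not a reference implementation: it assumes the graph is given as an adjacency list `adj` (a hypothetical input format, where `adj[u]` holds `(v, w)` pairs), always starts from vertex 1, and uses "lazy deletion", meaning stale heap entries are skipped when popped instead of being updated in place. The name `prim_heap` is made up for this example.

```python
import heapq


def prim_heap(n, adj):
    """
    Heap-based Prim sketch for vertices 1..n.
    adj[u] is a list of (v, w) pairs (assumed input format).
    Returns the MST weight, or float('inf') if the graph is disconnected.
    """
    if n == 0:
        return 0
    visited = [False] * (n + 1)
    heap = [(0, 1)]  # (weight of edge into the tree, vertex); root costs 0
    res, cnt = 0, 0
    while heap and cnt < n:
        w, u = heapq.heappop(heap)
        if visited[u]:
            continue  # stale entry: u was already added via a cheaper edge
        visited[u] = True
        res += w
        cnt += 1
        for v, wv in adj[u]:
            if not visited[v]:
                heapq.heappush(heap, (wv, v))
    return res if cnt == n else float('inf')
```

Each edge is pushed at most twice, so the heap never holds more than 2m entries, which gives O(m log m) = O(m log n) time. On dense graphs this is usually no better than the O(n^2) version.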
Prim’s Algorithm – O(n^2)
INF = 10**9  # sentinel for "no edge"; must be larger than any possible MST weight

def prim(n, g):
    """
    Prim's algorithm on an undirected graph given as an adjacency matrix.

    Args:
        n (int): number of vertices (vertices 1..n are used)
        g (list[list[int]]): (n+1) x (n+1) adjacency matrix.
            Use g[i][j] = INF if there is no edge; g[i][i] can be 0.

    Returns:
        Total weight of the MST, or INF if the graph is disconnected.
    """
    dist = [INF] * (n + 1)   # cheapest edge from each vertex to the current MST
    st = [False] * (n + 1)   # whether the vertex is already in the MST
    res = 0
    for i in range(n):
        # pick the vertex outside the MST that is closest to it
        t = -1
        for j in range(1, n + 1):
            if not st[j] and (t == -1 or dist[j] < dist[t]):
                t = j
        if i and dist[t] == INF:
            return INF  # not connected
        if i:
            res += dist[t]  # the first vertex is the root and adds no edge
        st[t] = True
        # relax edges from t
        for j in range(1, n + 1):
            if not st[j] and g[t][j] < dist[j]:
                dist[j] = g[t][j]
    return res
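The matrix format that `prim` expects is easy to get wrong, so here is a small sketch of how it could be built from an edge list. The helper name `build_matrix` is hypothetical, and the code assumes the same sentinel value `INF = 10**9` as `prim`, 1-based vertices, and edge weights well below `INF`. When the same pair of vertices appears more than once, only the cheapest edge is kept, since an MST never uses a more expensive parallel edge.

```python
INF = 10**9  # assumed to match the sentinel used by prim


def build_matrix(n, edges):
    """
    Build an (n+1) x (n+1) adjacency matrix for vertices 1..n
    from a list of undirected edges (a, b, w).
    """
    g = [[INF] * (n + 1) for _ in range(n + 1)]
    for i in range(n + 1):
        g[i][i] = 0
    for a, b, w in edges:
        # skip self-loops; keep only the cheapest of any parallel edges
        if a != b and w < g[a][b]:
            g[a][b] = g[b][a] = w
    return g
```

Row and column 0 are unused padding so that vertex numbers can index the matrix directly.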
Kruskal’s Algorithm – O(m log m)
def kruskal(n, edges):
    """
    n: number of vertices (1..n)
    edges: list of (a, b, w)
    returns: total weight of the MST, or float('inf') if the graph is not connected
    """
    parent = list(range(n + 1))

    def find(x):
        # find the root of x, compressing the path along the way
        if parent[x] != x:
            parent[x] = find(parent[x])
        return parent[x]

    def union(a, b):
        # merge the two components; returns False if a and b were already connected
        ra, rb = find(a), find(b)
        if ra == rb:
            return False
        parent[ra] = rb
        return True

    edges.sort(key=lambda e: e[2])  # note: sorts the caller's list in place
    res, cnt = 0, 0
    for a, b, w in edges:
        if union(a, b):
            res += w
            cnt += 1
            if cnt == n - 1:
                break
    return res if cnt == n - 1 else float('inf')
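One practical caveat with the recursive `find` above: because `union` attaches roots without looking at component size, the parent chains can grow to length n, and a deep chain can exceed Python's default recursion limit. A possible workaround, sketched here under the assumption that a standalone class is acceptable (the name `DSU` and its methods are illustrative, not part of the code above), is an iterative `find` with path halving plus union by size.

```python
class DSU:
    """Disjoint-set union for elements 0..n (index 0 is unused padding)."""

    def __init__(self, n):
        self.parent = list(range(n + 1))
        self.size = [1] * (n + 1)

    def find(self, x):
        # iterative find with path halving: each visited node
        # is re-pointed to its grandparent
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, a, b):
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return False  # already in the same component
        # attach the smaller tree under the larger one
        if self.size[ra] < self.size[rb]:
            ra, rb = rb, ra
        self.parent[rb] = ra
        self.size[ra] += self.size[rb]
        return True
```

Inside Kruskal's main loop, `union(a, b)` would then become a call on a `DSU(n)` instance; the rest of the algorithm stays the same.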