← run

py-11-dijkstra

1.000
8/8 tests · algorithms
Challenge · difficulty 5/5
# Dijkstra shortest paths (heapq)

Implement **`solution.py`** with:

```python
def dijkstra(graph: dict[str, list[tuple[str, float]]], start: str) -> dict[str, float]:
    ...
```

Compute the **shortest-path distance** from `start` to every reachable node in a
weighted **directed** graph.

- `graph[u]` is a list of `(v, weight)` edges from `u` to `v`. Weights are
  non-negative.
- Return a dict mapping each **reachable** node to its minimum total distance from
  `start`. `start` itself maps to `0.0`.
- **Unreachable nodes must be omitted** from the result (do not include them with
  `inf`).
- A node that appears only as an edge target (never as a key in `graph`) is a
  valid node with no outgoing edges.
- Use the standard library only — implement Dijkstra's algorithm with
  **`heapq`** as the priority queue. Do **not** use networkx or any third-party
  library here.

Example:

```python
g = {
    "a": [("b", 1.0), ("c", 4.0)],
    "b": [("c", 2.0), ("d", 5.0)],
    "c": [("d", 1.0)],
    "d": [],
}
dijkstra(g, "a")
# {"a": 0.0, "b": 1.0, "c": 3.0, "d": 4.0}

dijkstra({"a": [("b", 2.0)], "b": [], "island": [("a", 1.0)]}, "a")
# {"a": 0.0, "b": 2.0}   # "island" is unreachable from "a", omitted
```
tests/test_dijkstra.py
import math

import pytest

from solution import dijkstra


def test_basic_multipath():
    """Competing routes: a->b->c (3) beats a->c (4); a->b->c->d (4) beats a->b->d (6)."""
    edges = {
        "a": [("b", 1.0), ("c", 4.0)],
        "b": [("c", 2.0), ("d", 5.0)],
        "c": [("d", 1.0)],
        "d": [],
    }
    expected = {"a": 0.0, "b": 1.0, "c": 3.0, "d": 4.0}
    assert dijkstra(edges, "a") == expected


def test_start_distance_zero():
    """The start node maps to exactly ``0.0`` (a float), next to its reachable neighbour."""
    g = {"a": [("b", 7.0)], "b": []}
    out = dijkstra(g, "a")
    assert out["a"] == 0.0
    # The spec requires ``0.0`` specifically; an int ``0`` would pass the
    # equality check above, so pin the type as well.
    assert isinstance(out["a"], float)
    assert out == {"a": 0.0, "b": 7.0}


def test_unreachable_omitted():
    """Nodes with no path from the start are left out instead of mapped to inf."""
    result = dijkstra({"a": [("b", 2.0)], "b": [], "island": [("a", 1.0)]}, "a")
    # "island" has an edge *into* "a", but nothing leads from "a" back to it.
    assert "island" not in result
    assert result == {"a": 0.0, "b": 2.0}


def test_chooses_cheaper_route():
    """The two-hop path a->b->c (1 + 2 = 3) must beat the direct edge a->c (10)."""
    graph = {
        "a": [("b", 1.0), ("c", 10.0)],
        "b": [("c", 2.0)],
        "c": [],
    }
    assert dijkstra(graph, "a")["c"] == pytest.approx(3.0)


def test_single_node_no_edges():
    """A lone start node with no edges yields only itself at distance 0."""
    lone = {"a": []}
    assert dijkstra(lone, "a") == {"a": 0.0}


def test_target_only_node_has_no_outgoing():
    """A node that only ever appears as an edge target (never a key) is still reported."""
    result = dijkstra({"a": [("z", 5.0)]}, "a")
    assert result == {"a": 0.0, "z": 5.0}


def test_zero_weight_edges():
    """A chain of zero-weight edges keeps every downstream node at distance 0."""
    chain = {"a": [("b", 0.0)], "b": [("c", 0.0)], "c": []}
    expected = dict.fromkeys("abc", 0.0)
    assert dijkstra(chain, "a") == expected


def test_larger_graph_relaxation():
    """An improved distance must propagate through later nodes.

    The direct edge s->a (4) is beaten by s->b->a (1 + 2 = 3), and that
    improvement has to carry through to c and t.
    """
    g = {
        "s": [("a", 4.0), ("b", 1.0)],
        "b": [("a", 2.0), ("c", 5.0)],
        "a": [("c", 1.0)],
        "c": [("t", 3.0)],
        "t": [],
    }
    out = dijkstra(g, "s")
    # s->b(1)->a(3)->c(4)->t(7)
    assert out["a"] == pytest.approx(3.0)
    assert out["c"] == pytest.approx(4.0)
    assert out["t"] == pytest.approx(7.0)
    # Pin the whole mapping too, so wrong/missing entries for s and b, or
    # any extra keys, are caught (approx on a dict also compares key sets).
    assert out == pytest.approx({"s": 0.0, "b": 1.0, "a": 3.0, "c": 4.0, "t": 7.0})
    assert all(math.isfinite(v) for v in out.values())
System prompt
You are an expert programmer. Solve the task exactly as specified. Output your solution as fenced code blocks using the required file name(s) and the exact function/type signatures requested. Prefer correctness; do not include prose outside code unless asked.
Proposed solution
```python
import heapq

def dijkstra(graph: dict[str, list[tuple[str, float]]], start: str) -> dict[str, float]:
    """Return shortest-path distances from ``start`` to every reachable node.

    ``graph[u]`` lists the directed edges ``(v, weight)`` leaving ``u``; the
    task specifies non-negative weights. A node with no key in ``graph``
    (e.g. one that only appears as an edge target) has no outgoing edges.

    Args:
        graph: Adjacency list mapping each node to its outgoing edges.
        start: Source node; always present in the result with distance 0.0.

    Returns:
        Dict mapping each reachable node to its minimum total distance.
        Unreachable nodes are omitted rather than mapped to ``inf``.
    """
    import itertools

    inf = float("inf")
    # Only nodes that have actually been reached are stored, so no final
    # "drop the inf entries" pass is needed.
    best = {start: 0.0}
    # A monotonically increasing sequence number breaks ties between equal
    # distances, so heapq never compares node objects themselves (that
    # raises TypeError for mixed or unorderable node types).
    order = itertools.count()
    heap = [(0.0, next(order), start)]

    while heap:
        dist, _, node = heapq.heappop(heap)
        if dist > best[node]:
            # Stale entry: a shorter path to ``node`` was recorded after this push.
            continue
        for neighbor, weight in graph.get(node, ()):
            candidate = dist + weight
            # Strict ``<`` against an ``inf`` default means an infinite-weight
            # edge never makes its target reachable.
            if candidate < best.get(neighbor, inf):
                best[neighbor] = candidate
                heapq.heappush(heap, (candidate, next(order), neighbor))

    return best
```

Test output (stdout)

........                                                                 [100%]
8 passed in 0.01s