← run

py-02-csv-groupby

1.000
8/8 tests · data
Challenge · difficulty 2/5
# CSV Group Sum

Implement a file **`solution.py`** containing a function:

```python
def group_sum(csv_text: str, key_col: str, val_col: str) -> dict[str, float]:
    ...
```

Parse `csv_text` as CSV and return a dictionary mapping each distinct value in
column `key_col` to the **sum** of column `val_col` for all rows with that key.

Rules:

- The **first non-blank line** is the header row naming the columns.
- Columns are comma-separated. `key_col` and `val_col` name two of those columns.
- Values in `val_col` are numeric (int or float); sum them as floats.
- **Ignore blank lines** anywhere in the input (including trailing newlines).
- If the input has only a header (or is empty), return an empty dict `{}`.
- You may assume `key_col` and `val_col` exist in the header.

Example:

```python
csv_text = "name,amount\nalice,10\nbob,5\nalice,2.5\n"
group_sum(csv_text, "name", "amount") == {"alice": 12.5, "bob": 5.0}
```
tests/test_group_sum.py
from solution import group_sum


def test_basic_sum():
    """Repeated keys accumulate, and int/float values mix in one column."""
    expected = {"alice": 12.5, "bob": 5.0}
    text = "name,amount\nalice,10\nbob,5\nalice,2.5\n"
    assert group_sum(text, "name", "amount") == expected


def test_returns_floats():
    """Integer-looking inputs still produce float sums."""
    sums = group_sum("k,v\na,1\na,2\n", "k", "v")
    assert sums == {"a": 3.0}
    for total in sums.values():
        assert isinstance(total, float)


def test_ignores_blank_lines():
    """Blank lines before the header, between rows, and at the end are skipped."""
    text = "\n\nname,amount\nalice,1\n\nbob,2\n\n\n"
    result = group_sum(text, "name", "amount")
    assert result == {"alice": 1.0, "bob": 2.0}


def test_header_only():
    """A header with no data rows yields an empty mapping."""
    result = group_sum("name,amount\n", "name", "amount")
    assert result == {}


def test_empty_input():
    """Empty text and newline-only text both yield an empty mapping."""
    for text in ("", "\n\n"):
        assert group_sum(text, "name", "amount") == {}


def test_column_order_independent():
    """The key column does not have to precede the value column."""
    text = "amount,name\n10,x\n5,y\n3,x\n"
    assert group_sum(text, "name", "amount") == {"x": 13.0, "y": 5.0}


def test_three_columns_picks_right_ones():
    """Columns not named by key_col / val_col are ignored."""
    text = "region,product,sales\neast,a,100\nwest,b,50\neast,c,25\n"
    expected = {"east": 125.0, "west": 50.0}
    assert group_sum(text, "region", "sales") == expected


def test_negative_and_float_values():
    """Negative numbers and decimals are parsed and summed."""
    text = "g,n\nx,-1.5\nx,0.5\ny,4\n"
    expected = {"x": -1.0, "y": 4.0}
    assert group_sum(text, "g", "n") == expected
System prompt
You are an expert programmer. Solve the task exactly as specified. Output your solution as fenced code blocks using the required file name(s) and the exact function/type signatures requested. Prefer correctness; do not include prose outside code unless asked.
Proposed solution
```python
import csv
from io import StringIO

def group_sum(csv_text: str, key_col: str, val_col: str) -> dict[str, float]:
    """
    Sum ``val_col`` for each distinct value of ``key_col`` in CSV text.

    The text is parsed with the standard :mod:`csv` module, so quoted fields
    such as ``"Smith, J"`` and doubled quotes (``""``) are handled correctly.
    Rows that are empty or contain only whitespace are ignored wherever they
    appear; the first remaining row is the header. Every field is stripped of
    surrounding whitespace before it is used as a column name, key, or value.

    Args:
        csv_text: The CSV document as a single string.
        key_col: Header name of the column whose values group the rows.
        val_col: Header name of the numeric column to sum.

    Returns:
        A dict mapping each distinct key to the float sum of its values.
        It is empty when the input is empty or has only a header row.
        Data rows too short to contain both columns, or whose value cannot
        be parsed as a float, are skipped.

    Raises:
        ValueError: If there is at least one data row and ``key_col`` or
            ``val_col`` is missing from the header.
    """
    # newline="" lets the csv module see the raw line endings, so it handles
    # \r\n and newlines embedded in quoted fields (per the csv docs).
    rows: list[list[str]] = []
    for raw in csv.reader(StringIO(csv_text, newline="")):
        fields = [field.strip() for field in raw]
        # csv.reader yields [] for blank lines and ['   '] for
        # whitespace-only lines; both count as blank and are dropped.
        if any(fields):
            rows.append(fields)

    # Nothing but (at most) a header: no data to sum.
    if len(rows) < 2:
        return {}

    header, *data = rows
    try:
        key_index = header.index(key_col)
        val_index = header.index(val_col)
    except ValueError:
        # The spec guarantees both columns exist; fail loudly if it does not.
        raise ValueError(
            f"Column '{key_col}' or '{val_col}' not found in header"
        ) from None

    min_len = max(key_index, val_index) + 1
    result: dict[str, float] = {}
    for row in data:
        if len(row) < min_len:
            # Row is too short to hold both columns.
            continue
        try:
            val = float(row[val_index])
        except ValueError:
            # Best-effort: skip rows whose value is not numeric.
            continue
        key = row[key_index]
        result[key] = result.get(key, 0.0) + val
    return result
```

Test output (stdout)

........                                                                 [100%]
8 passed in 0.01s