← run

py-02-csv-groupby

1.000
8/8 tests· data
Challenge · difficulty 2/5
# CSV Group Sum

Implement a file **`solution.py`** containing a function:

```python
def group_sum(csv_text: str, key_col: str, val_col: str) -> dict[str, float]:
    ...
```

Parse `csv_text` as CSV and return a dictionary mapping each distinct value in
column `key_col` to the **sum** of column `val_col` for all rows with that key.

Rules:

- The **first non-blank line** is the header row naming the columns.
- Columns are comma-separated. `key_col` and `val_col` name two of those columns.
- Values in `val_col` are numeric (int or float); sum them as floats, so every value in the returned dict is a `float` even when all inputs are integers.
- **Ignore blank lines** anywhere in the input (including trailing newlines).
- If the input has only a header, is empty, or contains only blank lines, return an empty dict `{}`.
- You may assume `key_col` and `val_col` exist in the header.

Example:

```python
csv_text = "name,amount\nalice,10\nbob,5\nalice,2.5\n"
group_sum(csv_text, "name", "amount") == {"alice": 12.5, "bob": 5.0}
```
tests/test_group_sum.py
from solution import group_sum


def test_basic_sum():
    """Repeated keys accumulate; int and float values mix freely."""
    text = "name,amount\nalice,10\nbob,5\nalice,2.5\n"
    expected = {"alice": 12.5, "bob": 5.0}
    totals = group_sum(text, "name", "amount")
    assert totals == expected


def test_returns_floats():
    """Sums of integer-looking values are still returned as floats."""
    sums = group_sum("k,v\na,1\na,2\n", "k", "v")
    assert sums == {"a": 3.0}
    for total in sums.values():
        assert isinstance(total, float)


def test_ignores_blank_lines():
    """Leading, interior and trailing blank lines are all skipped."""
    text = "\n\nname,amount\nalice,1\n\nbob,2\n\n\n"
    totals = group_sum(text, "name", "amount")
    assert totals == {"alice": 1.0, "bob": 2.0}


def test_header_only():
    """A header with no data rows yields an empty mapping."""
    totals = group_sum("name,amount\n", "name", "amount")
    assert totals == {}


def test_empty_input():
    """Empty text and text made only of newlines both yield {}."""
    for blank_text in ("", "\n\n"):
        assert group_sum(blank_text, "name", "amount") == {}


def test_column_order_independent():
    """Columns are located by header name, not by position."""
    text = "amount,name\n10,x\n5,y\n3,x\n"
    expected = {"x": 13.0, "y": 5.0}
    assert group_sum(text, "name", "amount") == expected


def test_three_columns_picks_right_ones():
    """Unrelated columns are ignored when grouping and summing."""
    text = "region,product,sales\neast,a,100\nwest,b,50\neast,c,25\n"
    expected = {"east": 125.0, "west": 50.0}
    assert group_sum(text, "region", "sales") == expected


def test_negative_and_float_values():
    """Negative and fractional values are summed correctly."""
    text = "g,n\nx,-1.5\nx,0.5\ny,4\n"
    expected = {"x": -1.0, "y": 4.0}
    assert group_sum(text, "g", "n") == expected
System prompt
You are an expert programmer. Solve the task exactly as specified. Output your solution as fenced code blocks using the required file name(s) and the exact function/type signatures requested. Prefer correctness; do not include prose outside code unless asked.
Proposed solution
```python
def group_sum(csv_text: str, key_col: str, val_col: str) -> dict[str, float]:
    """Sum ``val_col`` for each distinct value of ``key_col`` in CSV text.

    The first non-blank row is the header naming the columns. Rows that are
    empty or contain only whitespace are ignored anywhere in the input.
    Fields are parsed with the standard ``csv`` module, so quoted fields that
    contain commas, and CRLF line endings, are handled correctly. A leading
    UTF-8 byte-order mark is dropped so the first header name still matches.

    Args:
        csv_text: The full CSV document.
        key_col: Header name of the column to group by.
        val_col: Header name of the numeric column to sum.

    Returns:
        A dict mapping each key to the float sum of its values. It is empty
        when the input has no header, has no data rows, or lacks either named
        column. Data rows too short to contain both columns, or whose value
        cannot be parsed as a float, are skipped.
    """
    import csv
    import io

    # newline="" hands raw line endings to the csv module, which understands
    # \r\n as well as \n (a plain split('\n') leaves '\r' on the last cell).
    text = csv_text.removeprefix("\ufeff")
    reader = csv.reader(io.StringIO(text, newline=""))

    # csv.reader yields [] for empty lines and e.g. ['  '] for whitespace-only
    # ones; both count as blank lines and are skipped.
    rows = (row for row in reader if any(cell.strip() for cell in row))

    header = next(rows, None)
    if header is None:
        return {}

    try:
        key_idx = header.index(key_col)
        val_idx = header.index(val_col)
    except ValueError:
        # The spec guarantees both columns exist; stay lenient if they don't.
        return {}

    # A row must reach the further of the two columns to be usable.
    min_len = max(key_idx, val_idx) + 1
    result: dict[str, float] = {}
    for row in rows:
        if len(row) < min_len:
            continue
        try:
            value = float(row[val_idx])
        except ValueError:
            # Best-effort: skip rows whose value is not numeric.
            continue
        key = row[key_idx]
        result[key] = result.get(key, 0.0) + value

    return result
```

Test output (stdout)

........                                                                 [100%]
8 passed in 0.01s