Files
SpendingAnalysis/tests/test_seed.py
2026-02-10 14:48:13 -05:00

83 lines
2.4 KiB
Python

# tests/test_seed.py
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from src.db import Base
from src.seed import seed_categories, seed_household, seed_default_rules, DEFAULT_CATEGORIES
from src.models import Category, HouseholdMember, CategorizationRule
def make_session():
    """Return a Session bound to a fresh in-memory SQLite database.

    Every table registered on ``Base.metadata`` is created before the
    session is returned.
    """
    memory_engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(bind=memory_engine)
    return Session(bind=memory_engine)
def test_seed_categories():
    """After seeding, Category has as many rows as DEFAULT_CATEGORIES has entries.

    Also spot-checks that a few well-known category names are present.
    """
    db = make_session()
    seed_categories(db)

    seeded = db.query(Category).all()
    assert len(seeded) == len(DEFAULT_CATEGORIES)

    seeded_names = {category.name for category in seeded}
    for expected_name in ("Groceries", "Transfer", "Income"):
        assert expected_name in seeded_names
def test_seed_categories_idempotent():
    """Calling seed_categories twice still yields len(DEFAULT_CATEGORIES) rows."""
    db = make_session()
    for _ in range(2):
        seed_categories(db)

    assert len(db.query(Category).all()) == len(DEFAULT_CATEGORIES)
def test_seed_household():
    """seed_household stores exactly one HouseholdMember with the given name."""
    db = make_session()
    seed_household(db, "Andrew", "self")

    household = db.query(HouseholdMember).all()
    assert len(household) == 1

    # Length was just asserted, so single-element unpacking cannot fail here.
    (only_member,) = household
    assert only_member.name == "Andrew"
def test_seed_default_rules():
    """With categories and one household member present, default rules are seeded.

    Checks that more than 20 rules exist and that exactly one rule whose
    pattern contains "Netflix" is linked to the "Subscriptions" category.
    """
    db = make_session()
    seed_categories(db)
    db.add(HouseholdMember(name="Andrew", relationship="self"))
    db.commit()

    seed_default_rules(db)

    all_rules = db.query(CategorizationRule).all()
    assert len(all_rules) > 20

    # Spot-check one known rule.
    netflix_rules = list(filter(lambda rule: "Netflix" in rule.pattern, all_rules))
    assert len(netflix_rules) == 1
    assert netflix_rules[0].category.name == "Subscriptions"
def test_seed_default_rules_idempotent():
    """Seeding default rules a second time must not change the rule count.

    Setup mirrors test_seed_default_rules: categories are seeded and one
    household member is committed before the rules are seeded.
    """
    session = make_session()
    seed_categories(session)
    member = HouseholdMember(name="Andrew", relationship="self")
    session.add(member)
    session.commit()

    seed_default_rules(session)
    count_after_first = session.query(CategorizationRule).count()
    # Guard against a vacuous pass: if the first call inserted nothing,
    # 0 == 0 would "prove" idempotence without exercising it.
    assert count_after_first > 0

    seed_default_rules(session)
    count_after_second = session.query(CategorizationRule).count()
    assert count_after_first == count_after_second
def test_seed_default_rules_skips_person_rules_without_members():
    """With no household members, no seeded rule carries an attributed_to_id."""
    db = make_session()
    seed_categories(db)
    # Deliberately no HouseholdMember rows before seeding the rules.
    seed_default_rules(db)

    # Rules requiring person attribution should have been skipped.
    attributed_rules = [
        rule
        for rule in db.query(CategorizationRule).all()
        if rule.attributed_to_id is not None
    ]
    assert not attributed_rules