# tests/test_seed.py
"""Tests for the seeding helpers in ``src.seed``.

Each test runs against its own fresh in-memory SQLite database so tests
cannot observe each other's data.
"""
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from src.db import Base
from src.models import CategorizationRule, Category, HouseholdMember
from src.seed import (
    DEFAULT_CATEGORIES,
    seed_categories,
    seed_default_rules,
    seed_household,
)


def make_session():
    """Return a Session bound to a new in-memory SQLite database.

    All tables registered on ``Base.metadata`` are created before the
    session is returned.
    """
    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)
    return Session(engine)


def _seed_categories_and_member(session):
    """Seed the default categories plus one household member, then commit.

    Shared setup for the rule-seeding tests, which need both categories and
    at least one member to exist before ``seed_default_rules`` runs.
    """
    seed_categories(session)
    session.add(HouseholdMember(name="Andrew", relationship="self"))
    session.commit()


def test_seed_categories():
    """seed_categories inserts exactly DEFAULT_CATEGORIES, incl. key names."""
    session = make_session()
    seed_categories(session)

    cats = session.query(Category).all()
    assert len(cats) == len(DEFAULT_CATEGORIES)

    names = {c.name for c in cats}
    assert "Groceries" in names
    assert "Transfer" in names
    assert "Income" in names


def test_seed_categories_idempotent():
    """Calling seed_categories twice must not create duplicate rows."""
    session = make_session()
    seed_categories(session)
    seed_categories(session)

    cats = session.query(Category).all()
    assert len(cats) == len(DEFAULT_CATEGORIES)


def test_seed_household():
    """seed_household creates a single member with the given name."""
    session = make_session()
    seed_household(session, "Andrew", "self")

    members = session.query(HouseholdMember).all()
    assert len(members) == 1
    assert members[0].name == "Andrew"


def test_seed_default_rules():
    """seed_default_rules creates the built-in rule set, e.g. Netflix rule."""
    session = make_session()
    _seed_categories_and_member(session)

    seed_default_rules(session)

    rules = session.query(CategorizationRule).all()
    assert len(rules) > 20

    # Check some known rules
    netflix = [r for r in rules if "Netflix" in r.pattern]
    assert len(netflix) == 1
    assert netflix[0].category.name == "Subscriptions"


def test_seed_default_rules_idempotent():
    """A second seed_default_rules call must not add any rules."""
    session = make_session()
    _seed_categories_and_member(session)

    seed_default_rules(session)
    count1 = session.query(CategorizationRule).count()
    seed_default_rules(session)
    count2 = session.query(CategorizationRule).count()

    assert count1 == count2


def test_seed_default_rules_skips_person_rules_without_members():
    """With no household members, no rule is attributed to a person."""
    session = make_session()
    seed_categories(session)
    # Don't add any household members

    seed_default_rules(session)

    rules = session.query(CategorizationRule).all()
    # Rules requiring person attribution should be skipped
    person_rules = [r for r in rules if r.attributed_to_id is not None]
    # NOTE(review): this also passes if seed_default_rules seeds nothing at
    # all; consider asserting that non-person rules are still created.
    assert len(person_rules) == 0