diff --git a/src/seed.py b/src/seed.py new file mode 100644 index 0000000..efa00fa --- /dev/null +++ b/src/seed.py @@ -0,0 +1,128 @@ +# src/seed.py +from sqlalchemy.orm import Session +from src.models import Category, CategorizationRule, HouseholdMember + +DEFAULT_CATEGORIES = [ + ("Income", None, None), + ("Housing", "needs", None), + ("Groceries", "needs", None), + ("Dining Out", "wants", None), + ("Transportation", "needs", None), + ("Gas", "needs", None), + ("Utilities", "needs", None), + ("Insurance", "needs", None), + ("Healthcare", "needs", None), + ("Entertainment", "wants", None), + ("Shopping", "wants", None), + ("Subscriptions", "wants", None), + ("Personal Care", "wants", None), + ("Family", "needs", None), + ("Gifts & Donations", "wants", None), + ("Debt Payment", "needs", None), + ("Savings", "savings", None), + ("Transfer", None, None), + ("Travel", "wants", None), + ("Home", "needs", None), + ("Professional Services", "needs", None), + ("Uncategorized", None, None), +] + + +def seed_categories(session: Session) -> None: + existing = {c.name for c in session.query(Category).all()} + for name, tag, icon in DEFAULT_CATEGORIES: + if name not in existing: + session.add(Category(name=name, default_tag=tag, icon=icon)) + session.commit() + + +DEFAULT_RULES = [ + # Income + ("CIMTECHNIQUES", "Income", None, "Andrew", 1), + ("OASISBATCH PAYROLL", "Income", None, "Donna", 1), + # Transfers + ("CHASE CREDIT CRD EPAY", "Transfer", None, None, 1), + ("CAPITAL ONE TRANSFER", "Transfer", None, None, 1), + ("AMEX EPAYMENT", "Transfer", None, "Donna", 1), + ("WAY2SAVE SAVINGS", "Transfer", None, None, 1), + ("AMZ_STORECRD_PMT", "Transfer", None, None, 1), + # Housing + ("FREEDOM MTG PYMTS", "Housing", "needs", None, 5), + ("Beaufort County", "Housing", "needs", None, 5), # property tax + # Utilities + ("DOMINION ENERGY", "Utilities", "needs", None, 5), + ("BEAUFORTJASP.*UTILITY", "Utilities", "needs", None, 5), + # Insurance + ("FARM BUREAU INS", "Insurance", "needs", None, 5), + # Transportation + ("VW CREDIT", "Transportation", "needs", None, 5), + ("TIDAL WAVE", "Transportation", "needs", None, 5), # car wash + # Debt + ("WSFS LOAN", "Debt Payment", "needs", None, 5), + ("AXOSADVSERV", "Debt Payment", "needs", None, 5), + # Groceries + ("PUBLIX|ALDI|PIGGLY WIGGLY|WAL-MART|WM SUPERCENTER|WALMART|TRADER JOE", "Groceries", "needs", None, 10), + ("HELLOFRESH", "Groceries", "needs", None, 10), + # Dining Out + ("CHICK-FIL-A|KFC|MCDONALD|TACO BELL|PANDA EXPRESS|LITTLE CAESARS|CULVERS|WAYBACK|SONIC DRIVE|JERSEY MIKE|COOK OUT|SURCHEROS|CHICKEN SALAD", "Dining Out", "wants", None, 10), + # Gas + ("CIRCLE K", "Gas", "needs", None, 10), + # Subscriptions + ("Netflix.com", "Subscriptions", "wants", None, 10), + ("Prime Video", "Subscriptions", "wants", None, 10), + ("APPLE.COM/BILL", "Subscriptions", "wants", None, 10), + ("Patreon", "Subscriptions", "wants", None, 10), + ("CBB WORLD", "Subscriptions", "wants", None, 10), + ("CLAUDE.AI", "Subscriptions", "wants", None, 10), + ("OPENROUTER", "Subscriptions", "wants", None, 10), + ("STORJ", "Subscriptions", "wants", None, 10), + # Healthcare + ("TR COUNSELING|RESOURCE MEDICAL|LABORATORY CORP|BEAUFORT PHCY|MED.*UNI MED|CVS/PHARMACY|WALGREENS", "Healthcare", "needs", None, 10), + # Home + ("ORKIN|LOWES|SHERWIN-WILLIAMS|GRAYCO HARDWARE|MAKERWORLD", "Home", "needs", None, 10), + # Professional Services + ("GEOARM|ALARMCLUB", "Professional Services", "needs", None, 10), + # Shopping (lower priority - catch-all for Amazon etc.) + ("AMAZON", "Shopping", "wants", None, 20), + # Personal Care + ("BARBERS OF THE LOW", "Personal Care", "wants", None, 10), +] + + +def seed_default_rules(session: Session) -> None: + """Seed default categorization rules. Only adds rules if none exist yet.""" + existing_count = session.query(CategorizationRule).count() + if existing_count > 0: + return # Don't overwrite user's rules + + for pattern, cat_name, tag_override, person_name, priority in DEFAULT_RULES: + category = session.query(Category).filter_by(name=cat_name).first() + if not category: + continue + + attributed_to = None + if person_name: + attributed_to = session.query(HouseholdMember).filter_by(name=person_name).first() + if not attributed_to: + continue # Skip rules requiring a person that doesn't exist yet + + rule = CategorizationRule( + pattern=pattern, + category_id=category.id, + tag_override=tag_override, + attributed_to_id=attributed_to.id if attributed_to else None, + priority=priority, + ) + session.add(rule) + + session.commit() + + +def seed_household(session: Session, name: str, relationship: str) -> HouseholdMember: + existing = session.query(HouseholdMember).filter_by(name=name).first() + if existing: + return existing + member = HouseholdMember(name=name, relationship=relationship) + session.add(member) + session.commit() + return member diff --git a/tests/test_seed.py b/tests/test_seed.py new file mode 100644 index 0000000..4dc2a7d --- /dev/null +++ b/tests/test_seed.py @@ -0,0 +1,82 @@ +# tests/test_seed.py +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from src.db import Base +from src.seed import seed_categories, seed_household, seed_default_rules, DEFAULT_CATEGORIES +from src.models import Category, HouseholdMember, CategorizationRule + + +def make_session(): + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + return Session(engine) + + +def test_seed_categories(): + session = make_session() + seed_categories(session) + cats = session.query(Category).all() + assert len(cats) == len(DEFAULT_CATEGORIES) + names = {c.name for c in cats} + assert "Groceries" in names + assert "Transfer" in names + assert "Income" in names + + +def test_seed_categories_idempotent(): + session = make_session() + seed_categories(session) + seed_categories(session) + cats = session.query(Category).all() + assert len(cats) == len(DEFAULT_CATEGORIES) + + +def test_seed_household(): + session = make_session() + seed_household(session, "Andrew", "self") + members = session.query(HouseholdMember).all() + assert len(members) == 1 + assert members[0].name == "Andrew" + + +def test_seed_default_rules(): + session = make_session() + seed_categories(session) + member = HouseholdMember(name="Andrew", relationship="self") + session.add(member) + session.commit() + + seed_default_rules(session) + rules = session.query(CategorizationRule).all() + assert len(rules) > 20 + + # Check some known rules + netflix = [r for r in rules if "Netflix" in r.pattern] + assert len(netflix) == 1 + assert netflix[0].category.name == "Subscriptions" + + +def test_seed_default_rules_idempotent(): + session = make_session() + seed_categories(session) + member = HouseholdMember(name="Andrew", relationship="self") + session.add(member) + session.commit() + + seed_default_rules(session) + count1 = session.query(CategorizationRule).count() + seed_default_rules(session) + count2 = session.query(CategorizationRule).count() + assert count1 == count2 + + +def test_seed_default_rules_skips_person_rules_without_members(): + session = make_session() + seed_categories(session) + # Don't add any household members + seed_default_rules(session) + rules = session.query(CategorizationRule).all() + # Rules requiring person attribution should be skipped + person_rules = [r for r in rules if r.attributed_to_id is not None] + assert len(person_rules) == 0