83 lines
2.4 KiB
Python
83 lines
2.4 KiB
Python
# tests/test_seed.py
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import Session
|
|
|
|
from src.db import Base
|
|
from src.seed import seed_categories, seed_household, seed_default_rules, DEFAULT_CATEGORIES
|
|
from src.models import Category, HouseholdMember, CategorizationRule
|
|
|
|
|
|
def make_session():
    """Build a Session on a brand-new in-memory SQLite database.

    Every table registered on ``Base.metadata`` is created before the
    session is returned, so each call starts from a freshly created schema
    with no rows inserted.
    """
    memory_engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(memory_engine)
    return Session(bind=memory_engine)
|
|
|
|
|
|
def test_seed_categories():
    """Seeding categories creates exactly as many rows as DEFAULT_CATEGORIES lists."""
    db = make_session()
    seed_categories(db)

    seeded = db.query(Category).all()
    assert len(seeded) == len(DEFAULT_CATEGORIES)

    # Spot-check a few well-known category names.
    seeded_names = {category.name for category in seeded}
    for expected in ("Groceries", "Transfer", "Income"):
        assert expected in seeded_names
|
|
|
|
|
|
def test_seed_categories_idempotent():
    """Running seed_categories twice still leaves one row per default category."""
    db = make_session()
    for _ in range(2):
        seed_categories(db)

    assert len(db.query(Category).all()) == len(DEFAULT_CATEGORIES)
|
|
|
|
|
|
def test_seed_household():
    """seed_household stores a single member carrying the supplied name."""
    db = make_session()
    seed_household(db, "Andrew", "self")

    members = db.query(HouseholdMember).all()
    assert len(members) == 1

    (only_member,) = members
    assert only_member.name == "Andrew"
|
|
|
|
|
|
def test_seed_default_rules():
    """Default rules are seeded in bulk and Netflix maps to Subscriptions."""
    db = make_session()
    seed_categories(db)
    db.add(HouseholdMember(name="Andrew", relationship="self"))
    db.commit()

    seed_default_rules(db)
    rules = db.query(CategorizationRule).all()
    # NOTE(review): 20 is a loose lower bound, not the exact number of defaults.
    assert len(rules) > 20

    # Spot-check a known rule: exactly one pattern mentions "Netflix",
    # and it is filed under the Subscriptions category.
    netflix_rules = [rule for rule in rules if "Netflix" in rule.pattern]
    assert len(netflix_rules) == 1
    assert netflix_rules[0].category.name == "Subscriptions"
|
|
|
|
|
|
def test_seed_default_rules_idempotent():
    """A second seed_default_rules run leaves the rule count unchanged."""
    db = make_session()
    seed_categories(db)
    db.add(HouseholdMember(name="Andrew", relationship="self"))
    db.commit()

    # Record the rule count after each of two consecutive seeding passes.
    counts = []
    for _ in range(2):
        seed_default_rules(db)
        counts.append(db.query(CategorizationRule).count())

    first_run, second_run = counts
    assert first_run == second_run
|
|
|
|
|
|
def test_seed_default_rules_skips_person_rules_without_members():
    """With no household members, no seeded rule is attributed to a person."""
    db = make_session()
    seed_categories(db)

    # Deliberately seed rules before any HouseholdMember exists.
    seed_default_rules(db)

    # NOTE(review): this only proves that no stored rule has attributed_to_id
    # set; it would also pass if person rules were inserted with a NULL
    # attribution instead of being skipped.
    stored_rules = db.query(CategorizationRule).all()
    assert not any(
        rule.attributed_to_id is not None for rule in stored_rules
    )
|