Files
SpendingAnalysis/tests/services/test_categorizer.py
andy b7746ece4f feat: add rule-based categorization engine with protocol interface
Implement Categorizer protocol and RuleBasedCategorizer service that
matches transactions against pipe-separated patterns ordered by priority,
with support for tag overrides and household member attribution.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 14:44:16 -05:00

111 lines
4.0 KiB
Python

import datetime
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from src.db import Base
from src.models import *
from src.seed import seed_categories
from src.services.categorizer import RuleBasedCategorizer, Categorizer
def make_session():
    """Build a Session on a brand-new in-memory SQLite database.

    Each call creates its own engine pointing at ``sqlite:///:memory:`` and
    creates every table registered on ``Base.metadata``, so tests that call
    this helper start from an empty schema rather than shared state.
    """
    mem_engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(mem_engine)
    fresh_session = Session(mem_engine)
    return fresh_session
def test_categorizer_protocol():
    """RuleBasedCategorizer is accepted as a Categorizer by isinstance()."""
    db = make_session()
    # NOTE(review): isinstance() against a typing.Protocol only works when the
    # protocol is decorated with @runtime_checkable; confirm that in
    # src.services.categorizer.
    categorizer = RuleBasedCategorizer(db)
    assert isinstance(categorizer, Categorizer)
def test_match_simple_pattern():
    """A single-token rule pattern assigns its category to a matching transaction."""
    db = make_session()
    seed_categories(db)
    groceries = db.query(Category).filter_by(name="Groceries").one()
    db.add(CategorizationRule(pattern="PUBLIX", category_id=groceries.id, priority=10))
    db.commit()

    categorizer = RuleBasedCategorizer(db)
    purchase = Transaction(
        date=datetime.date(2026, 1, 15),
        amount=-44.90,
        description="PUBLIX #1716",
        account_id=1,
    )
    outcome = categorizer.categorize(purchase)

    assert outcome is not None
    assert outcome.category_id == groceries.id
def test_match_pipe_separated_pattern():
    """Each ``|``-separated alternative of a single rule pattern matches on its own."""
    db = make_session()
    seed_categories(db)
    groceries = db.query(Category).filter_by(name="Groceries").one()
    db.add(
        CategorizationRule(
            pattern="PUBLIX|ALDI|PIGGLY WIGGLY",
            category_id=groceries.id,
            priority=10,
        )
    )
    db.commit()

    categorizer = RuleBasedCategorizer(db)
    descriptions = ("PUBLIX #1716", "ALDI 76180 BEAUFORT", "PIGGLY WIGGLY #286")
    for desc in descriptions:
        purchase = Transaction(
            date=datetime.date(2026, 1, 15),
            amount=-20.00,
            description=desc,
            account_id=1,
        )
        outcome = categorizer.categorize(purchase)
        assert outcome is not None, f"Failed to match: {desc}"
        assert outcome.category_id == groceries.id
def test_no_match_returns_none():
    """A transaction that no rule matches is left uncategorised (``None``)."""
    db = make_session()
    seed_categories(db)
    # No CategorizationRule is added in this test; whether seed_categories
    # creates any rules is not visible from this file.
    categorizer = RuleBasedCategorizer(db)
    unknown = Transaction(
        date=datetime.date(2026, 1, 15),
        amount=-10.00,
        description="UNKNOWN MERCHANT",
        account_id=1,
    )
    assert categorizer.categorize(unknown) is None
def test_priority_ordering():
    """When several rules match, the one with the lower ``priority`` value wins.

    Both "WAL-MART" (Groceries, priority 1) and "WAL" (Shopping, priority 10)
    match the description. The broad, low-precedence rule is deliberately
    persisted first. A categorizer that merely walked rules in insertion/id
    order would therefore pick Shopping. Only genuine priority ordering yields
    Groceries.
    """
    session = make_session()
    seed_categories(session)
    groceries = session.query(Category).filter_by(name="Groceries").one()
    shopping = session.query(Category).filter_by(name="Shopping").one()
    broad_rule = CategorizationRule(pattern="WAL", category_id=shopping.id, priority=10)
    specific_rule = CategorizationRule(pattern="WAL-MART", category_id=groceries.id, priority=1)
    # Flush the broad rule on its own so it is guaranteed to be INSERTed
    # before the specific one; otherwise insertion order could make this
    # test pass without priority being honoured.
    session.add(broad_rule)
    session.flush()
    session.add(specific_rule)
    session.commit()
    cat = RuleBasedCategorizer(session)
    txn = Transaction(date=datetime.date(2026, 1, 15), amount=-48.52, description="WAL-MART #7181", account_id=1)
    result = cat.categorize(txn)
    # Guard first so a non-match fails as an assertion, not an AttributeError.
    assert result is not None
    assert result.category_id == groceries.id
def test_tag_override():
    """A matching rule's ``tag_override`` is applied together with its category."""
    session = make_session()
    seed_categories(session)
    dining = session.query(Category).filter_by(name="Dining Out").one()
    rule = CategorizationRule(pattern="CHICK-FIL-A", category_id=dining.id, tag_override="needs", priority=10)
    session.add(rule)
    session.commit()
    cat = RuleBasedCategorizer(session)
    txn = Transaction(date=datetime.date(2026, 1, 15), amount=-5.39, description="CHICK-FIL-A #01476", account_id=1)
    result = cat.categorize(txn)
    # Guard first so a non-match fails as an assertion, not an AttributeError.
    assert result is not None
    assert result.category_id == dining.id
    assert result.tag == "needs"
def test_person_attribution():
    """A matching rule's ``attributed_to_id`` is carried onto the categorization result."""
    session = make_session()
    seed_categories(session)
    donna = HouseholdMember(name="Donna", relationship="wife")
    session.add(donna)
    # Flush so the member row is inserted and donna.id is populated before the
    # rule below references it.
    session.flush()
    income = session.query(Category).filter_by(name="Income").one()
    rule = CategorizationRule(pattern="OASISBATCH PAYROLL", category_id=income.id, attributed_to_id=donna.id, priority=10)
    session.add(rule)
    session.commit()
    cat = RuleBasedCategorizer(session)
    txn = Transaction(date=datetime.date(2026, 1, 9), amount=1087.83, description="OASISBATCH PAYROLL 260109 DONNA CONLON", account_id=1)
    result = cat.categorize(txn)
    # Guard first so a non-match fails as an assertion, not an AttributeError.
    assert result is not None
    assert result.category_id == income.id
    assert result.attributed_to_id == donna.id