diff --git a/src/services/categorizer.py b/src/services/categorizer.py
new file mode 100644
index 0000000..6db8410
--- /dev/null
+++ b/src/services/categorizer.py
@@ -0,0 +1,93 @@
+"""Rule-based categorization of transactions by description."""
+
+import re
+from dataclasses import dataclass
+from typing import Protocol, runtime_checkable
+
+from sqlalchemy.orm import Session
+
+from src.models.rule import CategorizationRule
+from src.models.transaction import Transaction
+
+
+@dataclass
+class CategoryResult:
+    """Outcome of matching a transaction against a categorization rule."""
+
+    category_id: int
+    tag: str | None = None
+    attributed_to_id: int | None = None
+
+
+@runtime_checkable
+class Categorizer(Protocol):
+    """Structural interface for objects that categorize transactions."""
+
+    def categorize(self, transaction: Transaction) -> CategoryResult | None: ...
+
+
+def _compile_pattern(raw: str | None) -> re.Pattern[str] | None:
+    """Build a matcher for a ``|``-separated list of literal substrings.
+
+    Alternatives are stripped and upper-cased, so the matcher must be run
+    against upper-cased text. Blank alternatives are dropped, because an
+    empty pattern matches every string and would turn a rule with a stray
+    ``|`` into a catch-all. Returns ``None`` when no alternative remains.
+    """
+    alternatives = (part.strip().upper() for part in (raw or "").split("|"))
+    escaped = [re.escape(alt) for alt in alternatives if alt]
+    return re.compile("|".join(escaped)) if escaped else None
+
+
+class RuleBasedCategorizer:
+    """Categorize transactions using ``CategorizationRule`` rows.
+
+    Rules are loaded once per instance, ordered by ascending ``priority``,
+    and the first rule whose pattern occurs in the description wins.
+    """
+
+    def __init__(self, session: Session):
+        self.session = session
+        # Loaded lazily by _load_rules(); None means "not loaded yet".
+        self._rules: list[CategorizationRule] | None = None
+
+    def _load_rules(self) -> list[CategorizationRule]:
+        """Return all rules in priority order, querying only on first use."""
+        if self._rules is None:
+            self._rules = (
+                self.session.query(CategorizationRule)
+                .order_by(CategorizationRule.priority)
+                .all()
+            )
+        return self._rules
+
+    def invalidate_cache(self) -> None:
+        """Drop the cached rules so the next lookup re-reads the database."""
+        self._rules = None
+
+    def categorize(self, transaction: Transaction) -> CategoryResult | None:
+        """Return the result for the first matching rule, or ``None``.
+
+        Both the description and each rule's ``|``-separated alternatives
+        are upper-cased, and a rule matches when any alternative occurs as
+        a substring. The tag is the rule's ``tag_override`` when set,
+        otherwise the ``default_tag`` of the rule's category, if any.
+        """
+        # A missing description cannot match any rule; treat it as empty.
+        desc = (transaction.description or "").upper()
+
+        for rule in self._load_rules():
+            matcher = _compile_pattern(rule.pattern)
+            if matcher is None or matcher.search(desc) is None:
+                continue
+
+            tag = rule.tag_override
+            if tag is None and rule.category:
+                tag = rule.category.default_tag
+            return CategoryResult(
+                category_id=rule.category_id,
+                tag=tag,
+                attributed_to_id=rule.attributed_to_id,
+            )
+
+        return None
diff --git a/tests/services/test_categorizer.py b/tests/services/test_categorizer.py
new file mode 100644
index 0000000..8e14f92
--- /dev/null
+++ b/tests/services/test_categorizer.py
@@ -0,0 +1,146 @@
+import datetime
+
+from sqlalchemy import create_engine
+from sqlalchemy.orm import Session
+
+from src.db import Base
+from src.models import CategorizationRule, Category, HouseholdMember, Transaction
+from src.seed import seed_categories
+from src.services.categorizer import RuleBasedCategorizer, Categorizer
+
+
+def make_session():
+    """Return a session on a fresh in-memory SQLite database with all tables."""
+    engine = create_engine("sqlite:///:memory:")
+    Base.metadata.create_all(engine)
+    return Session(engine)
+
+
+def test_categorizer_protocol():
+    """RuleBasedCategorizer implements Categorizer protocol."""
+    session = make_session()
+    cat = RuleBasedCategorizer(session)
+    assert isinstance(cat, Categorizer)
+
+
+def test_match_simple_pattern():
+    session = make_session()
+    seed_categories(session)
+    groceries = session.query(Category).filter_by(name="Groceries").one()
+    rule = CategorizationRule(pattern="PUBLIX", category_id=groceries.id, priority=10)
+    session.add(rule)
+    session.commit()
+
+    cat = RuleBasedCategorizer(session)
+    txn = Transaction(date=datetime.date(2026, 1, 15), amount=-44.90, description="PUBLIX #1716", account_id=1)
+    result = cat.categorize(txn)
+    assert result is not None
+    assert result.category_id == groceries.id
+
+
+def test_match_pipe_separated_pattern():
+    session = make_session()
+    seed_categories(session)
+    groceries = session.query(Category).filter_by(name="Groceries").one()
+    rule = CategorizationRule(pattern="PUBLIX|ALDI|PIGGLY WIGGLY", category_id=groceries.id, priority=10)
+    session.add(rule)
+    session.commit()
+
+    cat = RuleBasedCategorizer(session)
+    for desc in ["PUBLIX #1716", "ALDI 76180 BEAUFORT", "PIGGLY WIGGLY #286"]:
+        txn = Transaction(date=datetime.date(2026, 1, 15), amount=-20.00, description=desc, account_id=1)
+        result = cat.categorize(txn)
+        assert result is not None, f"Failed to match: {desc}"
+        assert result.category_id == groceries.id
+
+
+def test_match_is_case_insensitive():
+    session = make_session()
+    seed_categories(session)
+    groceries = session.query(Category).filter_by(name="Groceries").one()
+    rule = CategorizationRule(pattern="publix", category_id=groceries.id, priority=10)
+    session.add(rule)
+    session.commit()
+
+    cat = RuleBasedCategorizer(session)
+    txn = Transaction(date=datetime.date(2026, 1, 15), amount=-44.90, description="Publix #1716", account_id=1)
+    result = cat.categorize(txn)
+    assert result is not None
+    assert result.category_id == groceries.id
+
+
+def test_blank_alternative_is_not_a_catch_all():
+    """A stray ``|`` in a rule must not match unrelated descriptions."""
+    session = make_session()
+    seed_categories(session)
+    groceries = session.query(Category).filter_by(name="Groceries").one()
+    rule = CategorizationRule(pattern="PUBLIX| ", category_id=groceries.id, priority=10)
+    session.add(rule)
+    session.commit()
+
+    cat = RuleBasedCategorizer(session)
+    txn = Transaction(date=datetime.date(2026, 1, 15), amount=-10.00, description="UNKNOWN MERCHANT", account_id=1)
+    assert cat.categorize(txn) is None
+
+
+def test_no_match_returns_none():
+    session = make_session()
+    seed_categories(session)
+    cat = RuleBasedCategorizer(session)
+    txn = Transaction(date=datetime.date(2026, 1, 15), amount=-10.00, description="UNKNOWN MERCHANT", account_id=1)
+    result = cat.categorize(txn)
+    assert result is None
+
+
+def test_priority_ordering():
+    session = make_session()
+    seed_categories(session)
+    groceries = session.query(Category).filter_by(name="Groceries").one()
+    shopping = session.query(Category).filter_by(name="Shopping").one()
+    specific = CategorizationRule(pattern="WAL-MART", category_id=groceries.id, priority=1)
+    broad = CategorizationRule(pattern="WAL", category_id=shopping.id, priority=10)
+    # Insert the losing rule first so the test depends on ORDER BY priority, not insertion order.
+    session.add(broad)
+    session.flush()
+    session.add(specific)
+    session.commit()
+
+    cat = RuleBasedCategorizer(session)
+    txn = Transaction(date=datetime.date(2026, 1, 15), amount=-48.52, description="WAL-MART #7181", account_id=1)
+    result = cat.categorize(txn)
+    assert result is not None
+    assert result.category_id == groceries.id
+
+
+def test_tag_override():
+    session = make_session()
+    seed_categories(session)
+    dining = session.query(Category).filter_by(name="Dining Out").one()
+    rule = CategorizationRule(pattern="CHICK-FIL-A", category_id=dining.id, tag_override="needs", priority=10)
+    session.add(rule)
+    session.commit()
+
+    cat = RuleBasedCategorizer(session)
+    txn = Transaction(date=datetime.date(2026, 1, 15), amount=-5.39, description="CHICK-FIL-A #01476", account_id=1)
+    result = cat.categorize(txn)
+    assert result is not None
+    assert result.tag == "needs"
+
+
+def test_person_attribution():
+    session = make_session()
+    seed_categories(session)
+    donna = HouseholdMember(name="Donna", relationship="wife")
+    session.add(donna)
+    session.flush()
+    income = session.query(Category).filter_by(name="Income").one()
+    rule = CategorizationRule(pattern="OASISBATCH PAYROLL", category_id=income.id, attributed_to_id=donna.id, priority=10)
+    session.add(rule)
+    session.commit()
+
+    cat = RuleBasedCategorizer(session)
+    txn = Transaction(date=datetime.date(2026, 1, 9), amount=1087.83, description="OASISBATCH PAYROLL 260109 DONNA CONLON", account_id=1)
+    result = cat.categorize(txn)
+    assert result is not None
+    assert result.category_id == income.id
+    assert result.attributed_to_id == donna.id