# tests/services/test_forecasting.py
"""Tests for ForecastingService monthly/annual projections and what-if exclusions."""
import datetime
from contextlib import closing

from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from src.db import Base
from src.models import Account, Category, HouseholdMember, Transaction
from src.seed import seed_categories
from src.services.forecasting import ForecastingService


def make_session():
    """Return a Session on a fresh in-memory SQLite database with all tables created.

    The caller owns the session and must close it; the tests in this module
    wrap it in ``contextlib.closing``.
    """
    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)
    return Session(engine)


def make_forecast_data(session):
    """Seed categories, one household member, one credit account and grocery history.

    Inserts three "GROCERIES" transactions of -500, -600 and -700 on the 15th
    of October, November and December 2025 (an upward spending trend), all
    tagged "needs", then commits the session.

    Returns:
        The ``Account`` the transactions were recorded against.
    """
    seed_categories(session)

    member = HouseholdMember(name="Andrew", relationship="self")
    session.add(member)
    session.flush()  # populate member.id so it can be used as the account owner

    account = Account(
        name="Chase",
        institution="Chase",
        account_type="credit",
        owner_id=member.id,
    )
    session.add(account)
    session.flush()  # populate account.id for the transactions below

    groceries = session.query(Category).filter_by(name="Groceries").one()

    # 3 months of grocery spending: $500, $600, $700 (trending up).
    # NOTE(review): these dates are fixed in late 2025; if ForecastingService
    # only considers history relative to today, these tests may become
    # date-sensitive over time — confirm against the service implementation.
    for month, total in [(10, 500), (11, 600), (12, 700)]:
        session.add(Transaction(
            date=datetime.date(2025, month, 15),
            amount=-total,  # outflows are stored as negative amounts in this fixture
            description="GROCERIES",
            account_id=account.id,
            category_id=groceries.id,
            tag="needs",
        ))
    session.commit()
    return account


def test_monthly_forecast():
    """forecast_month yields exactly one Groceries row, projected as spending (< 0)."""
    with closing(make_session()) as session:
        make_forecast_data(session)
        svc = ForecastingService(session)

        forecast = svc.forecast_month()

        groceries = [f for f in forecast if f["category"] == "Groceries"]
        assert len(groceries) == 1
        assert groceries[0]["projected"] < 0


def test_annual_forecast():
    """forecast_year is non-empty and its projections sum to net spending (< 0)."""
    with closing(make_session()) as session:
        make_forecast_data(session)
        svc = ForecastingService(session)

        forecast = svc.forecast_year()

        assert len(forecast) > 0
        total = sum(f["projected"] for f in forecast)
        assert total < 0


def test_what_if_removes_recurring():
    """Excluding a description from forecast_month raises the projected total."""
    with closing(make_session()) as session:
        make_forecast_data(session)
        svc = ForecastingService(session)

        base = svc.forecast_month()
        adjusted = svc.forecast_month(exclude_descriptions=["GROCERIES"])

        base_total = sum(f["projected"] for f in base)
        adj_total = sum(f["projected"] for f in adjusted)
        # Spending is negative, so dropping the GROCERIES outflow moves the
        # total toward zero, i.e. less spending once it is excluded.
        assert adj_total > base_total