Files
SpendingAnalysis/tests/services/test_forecasting.py
andy db06108d2b feat: complete v1 implementation - all services, UI views, and tests
Adds remaining files from parallel development:
- Services: analysis, csv_reader, forecasting, normalizer, recurring
- UI: recurring_view, settings_view, sidebar, themes (dark/light)
- Tests: analysis, csv_reader, forecasting, import_categorize,
  normalizer, recurring, integration
- App entry point (main.py) and CLAUDE.md

52 tests passing across all modules.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 14:57:46 -05:00

71 lines
2.2 KiB
Python

# tests/services/test_forecasting.py
import datetime
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from src.db import Base
from src.models import *
from src.seed import seed_categories
from src.services.forecasting import ForecastingService
def make_session():
    """Return a new ORM session bound to a brand-new in-memory SQLite database.

    A fresh engine is created on every call, so each caller gets its own
    empty database. All tables registered on ``Base.metadata`` are created
    before the session is returned.
    """
    in_memory_engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(in_memory_engine)
    fresh_session = Session(bind=in_memory_engine)
    return fresh_session
def make_forecast_data(session):
    """Populate *session* with a small, upward-trending grocery history.

    The fixture does the following:

    - Seeds the default categories.
    - Creates one household member who owns a single "credit" account.
    - Books one "GROCERIES" transaction on the 15th of Oct, Nov and Dec 2025,
      for $500, $600 and $700 respectively.

    Amounts are stored as negative values; the tests in this module treat
    negative amounts as spending. Everything is committed before returning.

    Returns:
        The ``Account`` the transactions were booked against.
    """
    seed_categories(session)

    owner = HouseholdMember(name="Andrew", relationship="self")
    session.add(owner)
    session.flush()  # populate owner.id so the account can reference it

    card = Account(
        name="Chase",
        institution="Chase",
        account_type="credit",
        owner_id=owner.id,
    )
    session.add(card)
    session.flush()  # populate card.id for the transactions below

    grocery_category = session.query(Category).filter_by(name="Groceries").one()

    # Month -> grocery spend: $500, $600, $700 (trending up).
    monthly_spend = {10: 500, 11: 600, 12: 700}
    session.add_all(
        [
            Transaction(
                date=datetime.date(2025, month, 15),
                amount=-spend,
                description="GROCERIES",
                account_id=card.id,
                category_id=grocery_category.id,
                tag="needs",
            )
            for month, spend in monthly_spend.items()
        ]
    )
    session.commit()
    return card
def test_monthly_forecast():
    """A one-month forecast contains exactly one, negative, Groceries row."""
    session = make_session()
    try:
        make_forecast_data(session)
        svc = ForecastingService(session)
        # NOTE(review): the fixture data is pinned to Oct-Dec 2025, but
        # forecast_month() is called without a reference date. If the service
        # looks back from "today", this test may become date-dependent.
        # Confirm against ForecastingService.
        forecast = svc.forecast_month()
        groceries = [f for f in forecast if f["category"] == "Groceries"]
        assert len(groceries) == 1
        # Fixture spending is stored as negative amounts, so the projection
        # for the category should be negative too.
        assert groceries[0]["projected"] < 0
    finally:
        # Always release the session and its connection, even when an
        # assertion above fails.
        session.close()
def test_annual_forecast():
    """A full-year forecast is non-empty and projects net spending (< 0)."""
    session = make_session()
    try:
        make_forecast_data(session)
        svc = ForecastingService(session)
        # NOTE(review): forecast_year() is called without a reference date
        # while the fixture uses fixed 2025 dates. Confirm the service's
        # history window does not depend on the current date.
        forecast = svc.forecast_year()
        assert len(forecast) > 0
        total = sum(f["projected"] for f in forecast)
        # Only negative (spending) transactions were seeded.
        assert total < 0
    finally:
        # Always release the session and its connection, even when an
        # assertion above fails.
        session.close()
def test_what_if_removes_recurring():
    """Excluding transactions by description lowers projected spending.

    Despite the test name, the "what-if" exercised here removes
    transactions whose description matches an entry in
    ``exclude_descriptions``. The fixture's only data is "GROCERIES", so
    excluding it should leave a strictly larger (less negative) total.
    """
    session = make_session()
    try:
        make_forecast_data(session)
        svc = ForecastingService(session)
        base = svc.forecast_month()
        adjusted = svc.forecast_month(exclude_descriptions=["GROCERIES"])
        base_total = sum(f["projected"] for f in base)
        adj_total = sum(f["projected"] for f in adjusted)
        # Spending is negative, so less spending means a larger total.
        assert adj_total > base_total
    finally:
        # Always release the session and its connection, even when an
        # assertion above fails.
        session.close()