diff --git a/src/services/importer.py b/src/services/importer.py
new file mode 100644
index 0000000..92d0633
--- /dev/null
+++ b/src/services/importer.py
@@ -0,0 +1,131 @@
+import datetime
+from collections import Counter
+from pathlib import Path
+from sqlalchemy import func
+from sqlalchemy.orm import Session
+
+from src.models.transaction import Transaction
+from src.services.csv_reader import read_csv
+from src.services.normalizer import normalize_description
+
+
+class ImportService:
+    """Import CSV rows into ``Transaction`` records, skipping duplicates."""
+
+    def __init__(self, session: Session):
+        self.session = session
+
+    def import_csv(
+        self,
+        file_path: Path,
+        account_id: int,
+        column_map: dict,
+        amount_logic: str = "signed",
+    ) -> dict:
+        """Import the rows of *file_path* as transactions for *account_id*.
+
+        Rows whose date or amount cannot be parsed are skipped. A row is a
+        duplicate when the database already held at least as many matching
+        transactions (same date, amount, raw description and account) as
+        this file has produced so far; duplicates are counted, not stored.
+
+        Args:
+            file_path: CSV file handed to ``read_csv``.
+            account_id: Account the new transactions are attached to.
+            column_map: Maps "date", "amount", "description" and optionally
+                "source_category" to the key of that cell within a row.
+            amount_logic: Parsing mode forwarded to ``_parse_amount``.
+
+        Returns:
+            A dict with the keys "imported", "duplicates" and "total_rows".
+        """
+        rows = read_csv(file_path)
+        imported = 0
+        duplicates = 0
+
+        # Column lookups do not change per row, so resolve them once.
+        date_col = column_map["date"]
+        desc_col = column_map["description"]
+        category_col = column_map.get("source_category")
+
+        # Track how many times each key appears in the current batch so that
+        # legitimate repeated transactions (same date/amount/description) in a
+        # single file are all imported, while a second import of the same file
+        # correctly detects every row as a duplicate.
+        batch_counts: Counter = Counter()
+        # Matching transactions stored before this import, fetched once per
+        # distinct key rather than once per CSV row.
+        stored_counts: dict = {}
+
+        for row in rows:
+            date_val = self._parse_date(row[date_col])
+            amount_val = self._parse_amount(row, column_map, amount_logic)
+            # Validate first: a row that is skipped anyway must not abort the
+            # import because its other cells are missing.
+            if date_val is None or amount_val is None:
+                continue
+
+            raw_desc = row[desc_col].strip()
+            key = (date_val, amount_val, raw_desc, account_id)
+            batch_counts[key] += 1
+
+            if key not in stored_counts:
+                stored_counts[key] = (
+                    self.session.query(func.count(Transaction.id))
+                    .filter_by(date=date_val, amount=amount_val, raw_description=raw_desc, account_id=account_id)
+                    .scalar()
+                )
+
+            # If the DB already had at least as many copies as we've seen
+            # (including this one) in the current batch, it's a duplicate.
+            if stored_counts[key] >= batch_counts[key]:
+                duplicates += 1
+                continue
+
+            category = row.get(category_col) if category_col is not None else None
+            txn = Transaction(
+                date=date_val,
+                amount=amount_val,
+                description=normalize_description(raw_desc),
+                raw_description=raw_desc,
+                account_id=account_id,
+                source_category=(category or "").strip() or None,
+            )
+            self.session.add(txn)
+            imported += 1
+
+        self.session.commit()
+        return {"imported": imported, "duplicates": duplicates, "total_rows": len(rows)}
+
+    def _parse_date(self, val: str) -> datetime.date | None:
+        """Parse *val* as a date; return ``None`` when no format matches.
+
+        Tries MM/DD/YYYY, YYYY-MM-DD and MM-DD-YYYY after removing
+        surrounding whitespace and double quotes.
+        """
+        if not isinstance(val, str):
+            # A missing cell (e.g. ``None``) cannot be a date.
+            return None
+        val = val.strip().strip('"')
+        for fmt in ("%m/%d/%Y", "%Y-%m-%d", "%m-%d-%Y"):
+            try:
+                return datetime.datetime.strptime(val, fmt).date()
+            except ValueError:
+                continue
+        return None
+
+    def _parse_amount(self, row: dict, column_map: dict, logic: str) -> float | None:
+        """Return the row's amount as a float, or ``None`` if unavailable.
+
+        Only the "signed" logic is implemented: the "amount" cell is read as
+        a single signed number, ignoring quotes and thousands separators.
+        """
+        if logic != "signed":
+            # NOTE(review): unsupported modes yield None, so every row is
+            # skipped silently - confirm whether this should raise instead.
+            return None
+        try:
+            return float(row[column_map["amount"]].strip().strip('"').replace(",", ""))
+        except (ValueError, KeyError, IndexError, AttributeError):
+            # Missing column, short positional row, empty or non-numeric cell.
+            return None
diff --git a/tests/services/test_importer.py b/tests/services/test_importer.py
new file mode 100644
index 0000000..70fcb82
--- /dev/null
+++ b/tests/services/test_importer.py
@@ -0,0 +1,116 @@
+# tests/services/test_importer.py
+import datetime
+from pathlib import Path
+from sqlalchemy import create_engine
+from sqlalchemy.orm import Session
+
+from src.db import Base
+from src.models import *
+from src.seed import seed_categories
+from src.services.importer import ImportService
+
+RAWDATA = Path(__file__).parent.parent.parent / "rawdata"
+
+
+def make_session():
+    """Return a Session bound to a fresh in-memory SQLite database."""
+    engine = create_engine("sqlite:///:memory:")
+    Base.metadata.create_all(engine)
+    return Session(engine)
+
+
+def setup_chase_account(session):
+    """Add a member and a Chase credit account, seed categories, return the account."""
+    member = HouseholdMember(name="Andrew", relationship="self")
+    session.add(member)
+    session.flush()
+    account = Account(name="Chase Freedom", institution="Chase", account_type="credit", owner_id=member.id)
+    session.add(account)
+    session.flush()
+    seed_categories(session)
+    return account
+
+
+def setup_checking_account(session):
+    """Add a member and a shared Wells Fargo checking account, seed categories, return the account."""
+    member = HouseholdMember(name="Andrew", relationship="self")
+    session.add(member)
+    session.flush()
+    account = Account(name="WF Checking", institution="Wells Fargo", account_type="checking", owner_id=member.id, is_shared=True)
+    session.add(account)
+    session.flush()
+    seed_categories(session)
+    return account
+
+
+def test_import_chase_csv():
+    """Importing the Chase export stores every row and reports no duplicates."""
+    session = make_session()
+    account = setup_chase_account(session)
+    column_map = {
+        "date": "Transaction Date",
+        "amount": "Amount",
+        "description": "Description",
+        "source_category": "Category",
+    }
+    svc = ImportService(session)
+    result = svc.import_csv(
+        RAWDATA / "Chase0372_Activity20260101_20260210_20260210.CSV",
+        account_id=account.id,
+        column_map=column_map,
+        amount_logic="signed",
+    )
+    assert result["imported"] > 0
+    assert result["duplicates"] == 0
+    txns = session.query(Transaction).all()
+    assert len(txns) == result["imported"]
+    sephora = [t for t in txns if "SEPHORA" in t.description]
+    assert len(sephora) == 1
+    assert float(sephora[0].amount) == -75.00
+
+
+def test_import_checking_csv():
+    """Importing the checking export with positional column keys stores rows."""
+    session = make_session()
+    account = setup_checking_account(session)
+    column_map = {
+        "date": 0,
+        "amount": 1,
+        "description": 4,
+    }
+    svc = ImportService(session)
+    result = svc.import_csv(
+        RAWDATA / "Checking1.csv",
+        account_id=account.id,
+        column_map=column_map,
+        amount_logic="signed",
+    )
+    assert result["imported"] > 0
+    txns = session.query(Transaction).all()
+    assert len(txns) > 50
+
+
+def test_duplicate_detection():
+    """Importing the same file twice flags every row of the second run as a duplicate."""
+    session = make_session()
+    account = setup_chase_account(session)
+    column_map = {
+        "date": "Transaction Date",
+        "amount": "Amount",
+        "description": "Description",
+    }
+    svc = ImportService(session)
+    result1 = svc.import_csv(
+        RAWDATA / "Chase0372_Activity20260101_20260210_20260210.CSV",
+        account_id=account.id,
+        column_map=column_map,
+        amount_logic="signed",
+    )
+    result2 = svc.import_csv(
+        RAWDATA / "Chase0372_Activity20260101_20260210_20260210.CSV",
+        account_id=account.id,
+        column_map=column_map,
+        amount_logic="signed",
+    )
+    assert result2["duplicates"] == result1["imported"]
+    assert result2["imported"] == 0