feat: import wizard view with column mapping and file drop

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-10 14:52:33 -05:00
parent 112d3aa99d
commit cfde077055
2 changed files with 731 additions and 1 deletions

699
src/ui/import_view.py Normal file
View File

@@ -0,0 +1,699 @@
"""Import wizard view with three stages: file selection, column mapping, and results."""
from __future__ import annotations
import json
from pathlib import Path
from PySide6.QtCore import Qt, Signal, QThread
from PySide6.QtWidgets import (
QCheckBox,
QComboBox,
QFileDialog,
QFormLayout,
QFrame,
QGroupBox,
QHBoxLayout,
QHeaderView,
QLabel,
QLineEdit,
QProgressBar,
QPushButton,
QStackedWidget,
QTableWidget,
QTableWidgetItem,
QVBoxLayout,
QWidget,
)
from sqlalchemy.orm import Session
from src.models.account import Account
from src.models.csv_mapping import CsvMapping
from src.models.household import HouseholdMember
from src.services.csv_reader import detect_format
from src.services.importer import ImportService
# ---------------------------------------------------------------------------
# Stage 1: File Selection
# ---------------------------------------------------------------------------
class _DropZone(QFrame):
"""A large drop zone that accepts CSV files via drag-and-drop."""
file_selected = Signal(str)
def __init__(self, parent: QWidget | None = None):
super().__init__(parent)
self.setAcceptDrops(True)
self.setObjectName("drop-zone")
self.setMinimumSize(400, 200)
self.setFrameShape(QFrame.Shape.StyledPanel)
self.setStyleSheet(
"#drop-zone {"
" border: 2px dashed #aaa;"
" border-radius: 12px;"
" background: #f9f9f9;"
"}"
"#drop-zone[dragOver='true'] {"
" border-color: #4a90d9;"
" background: #eaf2fd;"
"}"
)
layout = QVBoxLayout(self)
layout.setAlignment(Qt.AlignmentFlag.AlignCenter)
icon_label = QLabel("Drag & Drop CSV File Here")
icon_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
icon_label.setStyleSheet("font-size: 18px; color: #666;")
layout.addWidget(icon_label)
or_label = QLabel("or")
or_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
or_label.setStyleSheet("color: #999;")
layout.addWidget(or_label)
browse_btn = QPushButton("Browse...")
browse_btn.setFixedWidth(120)
browse_btn.clicked.connect(self._browse)
layout.addWidget(browse_btn, alignment=Qt.AlignmentFlag.AlignCenter)
self._file_label = QLabel("")
self._file_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
self._file_label.setStyleSheet("color: #333; font-weight: bold; margin-top: 8px;")
layout.addWidget(self._file_label)
# -- drag/drop support --------------------------------------------------
def dragEnterEvent(self, event):
if event.mimeData().hasUrls():
for url in event.mimeData().urls():
if url.toLocalFile().lower().endswith(".csv"):
event.acceptProposedAction()
self.setProperty("dragOver", True)
self.style().unpolish(self)
self.style().polish(self)
return
event.ignore()
def dragLeaveEvent(self, event):
self.setProperty("dragOver", False)
self.style().unpolish(self)
self.style().polish(self)
def dropEvent(self, event):
self.setProperty("dragOver", False)
self.style().unpolish(self)
self.style().polish(self)
for url in event.mimeData().urls():
path = url.toLocalFile()
if path.lower().endswith(".csv"):
self._show_file(path)
self.file_selected.emit(path)
return
# -- browse dialog ------------------------------------------------------
def _browse(self):
path, _ = QFileDialog.getOpenFileName(
self, "Select CSV File", "", "CSV Files (*.csv);;All Files (*)"
)
if path:
self._show_file(path)
self.file_selected.emit(path)
def _show_file(self, path: str):
self._file_label.setText(Path(path).name)
def reset(self):
self._file_label.setText("")
class _FileSelectionPage(QWidget):
"""Stage 1 widget."""
file_selected = Signal(str)
def __init__(self, parent: QWidget | None = None):
super().__init__(parent)
layout = QVBoxLayout(self)
layout.setAlignment(Qt.AlignmentFlag.AlignCenter)
title = QLabel("Import Transactions")
title.setAlignment(Qt.AlignmentFlag.AlignCenter)
title.setStyleSheet("font-size: 24px; font-weight: bold; margin-bottom: 16px;")
layout.addWidget(title)
self._drop_zone = _DropZone()
self._drop_zone.file_selected.connect(self.file_selected.emit)
layout.addWidget(self._drop_zone)
def reset(self):
self._drop_zone.reset()
# ---------------------------------------------------------------------------
# Stage 2: Column Mapping
# ---------------------------------------------------------------------------
class _NewAccountGroup(QGroupBox):
"""Inline fields for creating a new account."""
def __init__(self, session: Session, parent: QWidget | None = None):
super().__init__("New Account", parent)
self.session = session
self.setVisible(False)
form = QFormLayout(self)
self.name_edit = QLineEdit()
form.addRow("Name:", self.name_edit)
self.institution_edit = QLineEdit()
form.addRow("Institution:", self.institution_edit)
self.type_combo = QComboBox()
self.type_combo.addItems(["checking", "credit"])
form.addRow("Type:", self.type_combo)
self.owner_combo = QComboBox()
self.owner_combo.addItem("(none)", None)
for member in self.session.query(HouseholdMember).all():
self.owner_combo.addItem(member.name, member.id)
form.addRow("Owner:", self.owner_combo)
def create_account(self) -> Account | None:
"""Create the account in the DB and return it, or None if name is blank."""
name = self.name_edit.text().strip()
if not name:
return None
acct = Account(
name=name,
institution=self.institution_edit.text().strip(),
account_type=self.type_combo.currentText(),
owner_id=self.owner_combo.currentData(),
)
self.session.add(acct)
self.session.flush()
return acct
def reset(self):
self.name_edit.clear()
self.institution_edit.clear()
self.type_combo.setCurrentIndex(0)
self.owner_combo.setCurrentIndex(0)
self.setVisible(False)
class _ColumnMappingPage(QWidget):
"""Stage 2 widget: preview table + column mapping controls."""
mapping_confirmed = Signal(dict) # emits full config dict
TARGET_FIELDS = ["(skip)", "Date", "Amount", "Description", "Source Category"]
def __init__(self, session: Session, parent: QWidget | None = None):
super().__init__(parent)
self.session = session
self._format_info: dict | None = None
self._file_path: str = ""
self._column_combos: list[QComboBox] = []
outer = QVBoxLayout(self)
# --- preview table ---
preview_label = QLabel("CSV Preview (first 5 rows)")
preview_label.setStyleSheet("font-weight: bold;")
outer.addWidget(preview_label)
self._table = QTableWidget()
self._table.setEditTriggers(QTableWidget.EditTrigger.NoEditTriggers)
self._table.setMaximumHeight(180)
outer.addWidget(self._table)
# --- column mapping combos row ---
mapping_label = QLabel("Map each column to a target field:")
mapping_label.setStyleSheet("font-weight: bold; margin-top: 8px;")
outer.addWidget(mapping_label)
self._mapping_row = QHBoxLayout()
outer.addLayout(self._mapping_row)
# --- amount logic ---
amount_row = QHBoxLayout()
amount_lbl = QLabel("Amount logic:")
self._amount_logic = QComboBox()
self._amount_logic.addItems(["signed", "separate debit/credit"])
amount_row.addWidget(amount_lbl)
amount_row.addWidget(self._amount_logic)
amount_row.addStretch()
outer.addLayout(amount_row)
# --- account selection ---
acct_row = QHBoxLayout()
acct_lbl = QLabel("Account:")
self._account_combo = QComboBox()
self._new_account_btn = QPushButton("New Account...")
self._new_account_btn.clicked.connect(self._toggle_new_account)
acct_row.addWidget(acct_lbl)
acct_row.addWidget(self._account_combo)
acct_row.addWidget(self._new_account_btn)
acct_row.addStretch()
outer.addLayout(acct_row)
self._new_account_group = _NewAccountGroup(session)
outer.addWidget(self._new_account_group)
# --- save mapping ---
save_row = QHBoxLayout()
self._save_check = QCheckBox("Save this mapping")
self._mapping_name_edit = QLineEdit()
self._mapping_name_edit.setPlaceholderText("Mapping name")
self._mapping_name_edit.setEnabled(False)
self._save_check.toggled.connect(self._mapping_name_edit.setEnabled)
save_row.addWidget(self._save_check)
save_row.addWidget(self._mapping_name_edit)
save_row.addStretch()
outer.addLayout(save_row)
# --- confirm button ---
self._confirm_btn = QPushButton("Import")
self._confirm_btn.setFixedWidth(140)
self._confirm_btn.clicked.connect(self._on_confirm)
outer.addWidget(self._confirm_btn, alignment=Qt.AlignmentFlag.AlignRight)
outer.addStretch()
# -- public helpers -----------------------------------------------------
def load_file(self, file_path: str):
"""Populate the preview table and mapping combos from the selected CSV."""
self._file_path = file_path
self._format_info = detect_format(Path(file_path))
self._populate_preview()
self._populate_account_combo()
self._check_saved_mapping()
def reset(self):
self._table.clear()
self._table.setRowCount(0)
self._table.setColumnCount(0)
self._clear_mapping_combos()
self._save_check.setChecked(False)
self._mapping_name_edit.clear()
self._new_account_group.reset()
self._format_info = None
self._file_path = ""
# -- internal -----------------------------------------------------------
def _populate_preview(self):
info = self._format_info
if info is None:
return
preview = info["preview"]
headers = info.get("headers")
col_count = len(headers) if headers else info.get("column_count", 0)
self._table.setColumnCount(col_count)
self._table.setRowCount(len(preview))
if headers:
self._table.setHorizontalHeaderLabels(headers)
else:
self._table.setHorizontalHeaderLabels([f"Col {i}" for i in range(col_count)])
for r, row in enumerate(preview):
for c, cell in enumerate(row):
self._table.setItem(r, c, QTableWidgetItem(cell))
self._table.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeMode.Stretch)
# Build mapping combo boxes
self._clear_mapping_combos()
col_labels = headers if headers else [f"Col {i}" for i in range(col_count)]
for _i, col_name in enumerate(col_labels):
combo = QComboBox()
combo.addItems(self.TARGET_FIELDS)
# Try smart defaults
lower = col_name.lower()
if "date" in lower:
combo.setCurrentText("Date")
elif "amount" in lower or "debit" in lower or "credit" in lower:
combo.setCurrentText("Amount")
elif "desc" in lower or "memo" in lower or "narr" in lower:
combo.setCurrentText("Description")
elif "categ" in lower or "type" in lower:
combo.setCurrentText("Source Category")
vbox = QVBoxLayout()
lbl = QLabel(col_name)
lbl.setAlignment(Qt.AlignmentFlag.AlignCenter)
lbl.setStyleSheet("font-size: 11px;")
vbox.addWidget(lbl)
vbox.addWidget(combo)
self._mapping_row.addLayout(vbox)
self._column_combos.append(combo)
def _clear_mapping_combos(self):
self._column_combos.clear()
while self._mapping_row.count():
item = self._mapping_row.takeAt(0)
if item.layout():
while item.layout().count():
child = item.layout().takeAt(0)
if child.widget():
child.widget().deleteLater()
elif item.widget():
item.widget().deleteLater()
def _populate_account_combo(self):
self._account_combo.clear()
accounts = self.session.query(Account).all()
for acct in accounts:
label = f"{acct.name} ({acct.institution})"
self._account_combo.addItem(label, acct.id)
if not accounts:
self._account_combo.addItem("(no accounts -- create one below)", None)
def _check_saved_mapping(self):
"""Check if a saved CsvMapping matches this file's fingerprint."""
info = self._format_info
if info is None:
return
headers = info.get("headers")
if headers is None:
return
fingerprint = ",".join(headers)
mapping = (
self.session.query(CsvMapping)
.filter_by(fingerprint=fingerprint)
.first()
)
if mapping:
self._apply_saved_mapping(mapping)
def _apply_saved_mapping(self, mapping: CsvMapping):
"""Pre-populate combo boxes from a saved mapping."""
try:
col_map = json.loads(mapping.column_map)
except (json.JSONDecodeError, TypeError):
return
# Reverse map: target_field -> column_key
reverse = {}
for col_key, target in col_map.items():
reverse[target] = col_key
info = self._format_info
headers = info.get("headers") if info else None
if headers is None:
return
for i, combo in enumerate(self._column_combos):
header = headers[i] if i < len(headers) else str(i)
if header in reverse:
# The reverse map tells us which target field this column should map to
pass
# Actually, column_map stores {target -> column_key}, so look up by header
for target_field, col_key in col_map.items():
if col_key == header or col_key == str(i):
display = target_field.title()
idx = combo.findText(display)
if idx >= 0:
combo.setCurrentIndex(idx)
# Set amount logic
idx = self._amount_logic.findText(mapping.amount_logic)
if idx >= 0:
self._amount_logic.setCurrentIndex(idx)
# Set account
acct_idx = self._account_combo.findData(mapping.account_id)
if acct_idx >= 0:
self._account_combo.setCurrentIndex(acct_idx)
def _toggle_new_account(self):
visible = not self._new_account_group.isVisible()
self._new_account_group.setVisible(visible)
self._new_account_btn.setText("Cancel" if visible else "New Account...")
def _build_column_map(self) -> dict[str, str] | None:
"""Build the column_map dict expected by ImportService.
Returns a dict like {"date": "Date Column", "amount": "Amount Column", ...}
or None if required fields are missing.
"""
info = self._format_info
if info is None:
return None
headers = info.get("headers")
col_labels = headers if headers else [str(i) for i in range(info.get("column_count", 0))]
result: dict[str, str] = {}
target_to_key = {
"Date": "date",
"Amount": "amount",
"Description": "description",
"Source Category": "source_category",
}
for i, combo in enumerate(self._column_combos):
text = combo.currentText()
if text == "(skip)":
continue
key = target_to_key.get(text)
if key:
# The column_map value is what ImportService uses to look up
# the row dict. For header CSVs this is the header string;
# for headerless CSVs this is the integer index.
col_id = col_labels[i] if headers else int(col_labels[i])
result[key] = col_id
# Require at least date, amount, description
for required in ("date", "amount", "description"):
if required not in result:
return None
return result
def _on_confirm(self):
column_map = self._build_column_map()
if column_map is None:
return # TODO: show validation message
# Resolve account
account_id = self._account_combo.currentData()
if self._new_account_group.isVisible():
acct = self._new_account_group.create_account()
if acct is not None:
account_id = acct.id
self._populate_account_combo()
if account_id is None:
return # TODO: show validation message
# Save mapping if requested
if self._save_check.isChecked():
mapping_name = self._mapping_name_edit.text().strip()
if mapping_name and self._format_info:
headers = self._format_info.get("headers")
fingerprint = ",".join(headers) if headers else ""
saved = CsvMapping(
name=mapping_name,
fingerprint=fingerprint,
column_map=json.dumps(column_map),
amount_logic=self._amount_logic.currentText(),
account_id=account_id,
)
self.session.add(saved)
self.session.flush()
config = {
"file_path": self._file_path,
"account_id": account_id,
"column_map": column_map,
"amount_logic": self._amount_logic.currentText(),
}
self.mapping_confirmed.emit(config)
# ---------------------------------------------------------------------------
# Stage 3: Import Results
# ---------------------------------------------------------------------------
class _ImportWorker(QThread):
"""Runs the import in a background thread so the UI stays responsive."""
finished = Signal(dict)
def __init__(self, session: Session, config: dict, parent=None):
super().__init__(parent)
self.session = session
self.config = config
def run(self):
svc = ImportService(self.session)
result = svc.import_csv(
file_path=Path(self.config["file_path"]),
account_id=self.config["account_id"],
column_map=self.config["column_map"],
amount_logic=self.config["amount_logic"],
)
self.finished.emit(result)
class _ImportResultsPage(QWidget):
"""Stage 3 widget: progress indicator and results summary."""
import_another = Signal()
view_transactions = Signal()
def __init__(self, parent: QWidget | None = None):
super().__init__(parent)
layout = QVBoxLayout(self)
layout.setAlignment(Qt.AlignmentFlag.AlignCenter)
title = QLabel("Importing...")
title.setObjectName("results-title")
title.setAlignment(Qt.AlignmentFlag.AlignCenter)
title.setStyleSheet("font-size: 22px; font-weight: bold; margin-bottom: 12px;")
layout.addWidget(title)
self._title = title
self._progress = QProgressBar()
self._progress.setRange(0, 0) # indeterminate
self._progress.setFixedWidth(400)
layout.addWidget(self._progress, alignment=Qt.AlignmentFlag.AlignCenter)
self._results_frame = QFrame()
self._results_frame.setVisible(False)
results_layout = QVBoxLayout(self._results_frame)
self._imported_label = QLabel()
self._imported_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
results_layout.addWidget(self._imported_label)
self._duplicates_label = QLabel()
self._duplicates_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
results_layout.addWidget(self._duplicates_label)
self._total_label = QLabel()
self._total_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
results_layout.addWidget(self._total_label)
btn_row = QHBoxLayout()
btn_row.setAlignment(Qt.AlignmentFlag.AlignCenter)
another_btn = QPushButton("Import Another")
another_btn.clicked.connect(self.import_another.emit)
btn_row.addWidget(another_btn)
view_btn = QPushButton("View Transactions")
view_btn.clicked.connect(self.view_transactions.emit)
btn_row.addWidget(view_btn)
results_layout.addLayout(btn_row)
layout.addWidget(self._results_frame)
def show_progress(self):
self._title.setText("Importing...")
self._progress.setVisible(True)
self._progress.setRange(0, 0)
self._results_frame.setVisible(False)
def show_results(self, results: dict):
self._progress.setVisible(False)
self._results_frame.setVisible(True)
imported = results.get("imported", 0)
duplicates = results.get("duplicates", 0)
total = results.get("total_rows", 0)
uncategorized = imported # conservative default; real count requires query
self._title.setText("Import Complete")
self._imported_label.setText(f"Imported: {imported} transactions")
self._duplicates_label.setText(f"Duplicates skipped: {duplicates}")
self._total_label.setText(f"Total rows processed: {total}")
def reset(self):
self.show_progress()
# ---------------------------------------------------------------------------
# ImportView: the top-level stacked widget
# ---------------------------------------------------------------------------
class ImportView(QWidget):
"""Three-stage import wizard exposed to MainWindow."""
import_complete = Signal()
navigate_to = Signal(str)
STAGE_FILE = 0
STAGE_MAPPING = 1
STAGE_RESULTS = 2
def __init__(self, session: Session, parent: QWidget | None = None):
super().__init__(parent)
self.session = session
self._worker: _ImportWorker | None = None
layout = QVBoxLayout(self)
layout.setContentsMargins(24, 24, 24, 24)
self._stack = QStackedWidget()
layout.addWidget(self._stack)
# Stage 1
self._file_page = _FileSelectionPage()
self._file_page.file_selected.connect(self._on_file_selected)
self._stack.addWidget(self._file_page)
# Stage 2
self._mapping_page = _ColumnMappingPage(session)
self._mapping_page.mapping_confirmed.connect(self._on_mapping_confirmed)
self._stack.addWidget(self._mapping_page)
# Stage 3
self._results_page = _ImportResultsPage()
self._results_page.import_another.connect(self._go_to_file_selection)
self._results_page.view_transactions.connect(
lambda: self.navigate_to.emit("transactions")
)
self._stack.addWidget(self._results_page)
self._stack.setCurrentIndex(self.STAGE_FILE)
# -- navigation ---------------------------------------------------------
def _on_file_selected(self, path: str):
self._mapping_page.load_file(path)
self._stack.setCurrentIndex(self.STAGE_MAPPING)
def _on_mapping_confirmed(self, config: dict):
self._stack.setCurrentIndex(self.STAGE_RESULTS)
self._results_page.show_progress()
self._run_import(config)
def _go_to_file_selection(self):
self._file_page.reset()
self._mapping_page.reset()
self._results_page.reset()
self._stack.setCurrentIndex(self.STAGE_FILE)
# -- import execution ---------------------------------------------------
def _run_import(self, config: dict):
self._worker = _ImportWorker(self.session, config)
self._worker.finished.connect(self._on_import_finished)
self._worker.start()
def _on_import_finished(self, results: dict):
self._results_page.show_results(results)
self.import_complete.emit()
self._worker = None

View File

@@ -1,7 +1,13 @@
from pathlib import Path
from PySide6.QtWidgets import QMainWindow, QHBoxLayout, QWidget, QStackedWidget, QLabel
from PySide6.QtCore import Qt
from src.ui.sidebar import Sidebar
from src.ui.import_view import ImportView
from src.ui.transactions_view import TransactionsView
from src.ui.analysis_view import AnalysisView
from src.ui.recurring_view import RecurringView
from src.ui.settings_view import SettingsView
class MainWindow(QMainWindow):
@@ -25,17 +31,42 @@ class MainWindow(QMainWindow):
# Build views — real implementations where available, placeholders otherwise
self._views = {}
self._import_view = ImportView(session)
self._transactions_view = TransactionsView(session)
self._analysis_view = AnalysisView(session)
self._recurring_view = RecurringView(session)
self._settings_view = SettingsView(session)
for label, key in Sidebar.VIEWS:
if key == "transactions":
if key == "import":
self._views[key] = self.stack.addWidget(self._import_view)
elif key == "transactions":
self._views[key] = self.stack.addWidget(self._transactions_view)
elif key == "analysis":
self._views[key] = self.stack.addWidget(self._analysis_view)
elif key == "recurring":
self._views[key] = self.stack.addWidget(self._recurring_view)
elif key == "settings":
self._views[key] = self.stack.addWidget(self._settings_view)
else:
placeholder = QLabel(f"{label} View")
placeholder.setAlignment(Qt.AlignmentFlag.AlignCenter)
placeholder.setStyleSheet("font-size: 24px; color: #888;")
self._views[key] = self.stack.addWidget(placeholder)
# Wire up ImportView signals
self._import_view.import_complete.connect(self._transactions_view.refresh)
self._import_view.navigate_to.connect(self._switch_view)
self.sidebar.view_changed.connect(self._switch_view)
self.sidebar.theme_toggled.connect(self._apply_theme)
# Apply dark theme by default
self._apply_theme("dark")
def _switch_view(self, key: str):
self.stack.setCurrentIndex(self._views[key])
def _apply_theme(self, theme: str):
qss_path = Path(__file__).parent / "themes" / f"{theme}.qss"
if qss_path.exists():
self.setStyleSheet(qss_path.read_text(encoding="utf-8"))