Skip to content

Infrastructure Layer

The infrastructure layer implements the repository interfaces defined in the domain layer. It is the only layer that talks to DuckDB, reads pickle files, or makes external HTTP calls.

DuckDB Adapter

src.infrastructure.db.duckdb_adapter

DuckDB adapter – infrastructure layer database connection.

All SQL lives here. Domain layer remains pure Python.

Functions

get_connection
get_connection(read_only: bool = True) -> Generator[duckdb.DuckDBPyConnection, None, None]

Context manager for a DuckDB connection.

Parameters:

Name Type Description Default
read_only bool

If True (default), opens a read-only connection safe for concurrent API workers. Set False for write operations.

True

Yields:

Type Description
DuckDBPyConnection

An active DuckDB connection that is closed on exit.

Source code in src/infrastructure/db/duckdb_adapter.py
@contextmanager
def get_connection(read_only: bool = True) -> Generator[duckdb.DuckDBPyConnection, None, None]:
    """Open a DuckDB connection for the duration of a ``with`` block.

    The connection targets the module-level ``_DB_PATH`` database and is
    closed when the block exits, whether normally or via an exception.

    Args:
        read_only: If True (default), opens a read-only connection safe for
                   concurrent API workers. Set False for write operations.

    Yields:
        An active DuckDB connection that is closed on exit.
    """
    connection = duckdb.connect(database=_DB_PATH, read_only=read_only)
    logger.debug("duckdb.connection.opened", path=_DB_PATH, read_only=read_only)
    try:
        yield connection
    finally:
        # Close first, then record it, so the log line is only emitted once
        # close() has actually returned.
        connection.close()
        logger.debug("duckdb.connection.closed")

Customer Repository (DuckDB)

src.infrastructure.repositories.customer_repository

DuckDB implementation of CustomerRepository.

Classes

DuckDBCustomerRepository

Bases: CustomerRepository

Reads Customer entities from the DuckDB warehouse.

Source code in src/infrastructure/repositories/customer_repository.py
class DuckDBCustomerRepository(CustomerRepository):
    """Reads Customer entities from the DuckDB warehouse.

    Every method opens its own connection through ``get_connection``; only
    ``save`` requests a writable one.
    """

    def get_by_id(self, customer_id: str) -> Customer | None:
        """Fetch a single customer by ID.

        Args:
            customer_id: Value matched against ``raw.customers.customer_id``.

        Returns:
            The matching Customer, or None if no row matches.
        """
        with get_connection() as conn:
            row = conn.execute(
                """
                SELECT customer_id, industry, plan_tier, signup_date, mrr, churn_date
                FROM raw.customers
                WHERE customer_id = ?
                """,
                [customer_id],
            ).fetchone()

        if row is None:
            return None
        return self._row_to_entity(row)

    def get_all_active(self) -> Sequence[Customer]:
        """Return all customers without a churn_date, highest MRR first."""
        with get_connection() as conn:
            rows = conn.execute(
                """
                SELECT customer_id, industry, plan_tier, signup_date, mrr, churn_date
                FROM raw.customers
                WHERE churn_date IS NULL
                ORDER BY mrr DESC
                """
            ).fetchall()
        return [self._row_to_entity(row) for row in rows]

    def get_sample(self, n: int) -> Sequence[Customer]:
        """Return a sample of n customers using DuckDB reservoir sampling.

        DuckDB's SAMPLE clause does not support parameterized queries — only
        literal integer constants are accepted — so n has to be interpolated
        into the SQL text. To keep that safe regardless of what a caller
        passes, n is coerced with ``int()`` and rejected when negative before
        it reaches the query. Callers are still expected to clamp n to ≤ 100.

        Args:
            n: Number of rows to sample.

        Returns:
            The sampled customers.

        Raises:
            ValueError: If n is negative or cannot be converted to an integer.
            TypeError: If n is of a type ``int()`` does not accept (e.g. None).
        """
        # Only a genuine integer may ever be spliced into the SQL string.
        sample_size = int(n)
        if sample_size < 0:
            raise ValueError(f"n must be a non-negative integer, got {n!r}")

        with get_connection() as conn:
            # NOTE(review): REPEATABLE(42) pins the sampling seed, so the
            # sample is presumably reproducible across calls — confirm
            # against the DuckDB SAMPLE documentation.
            rows = conn.execute(
                f"""
                SELECT customer_id, industry, plan_tier, signup_date, mrr, churn_date
                FROM raw.customers
                USING SAMPLE reservoir({sample_size} ROWS) REPEATABLE(42)
                """
            ).fetchall()
        return [self._row_to_entity(row) for row in rows]

    def save(self, customer: Customer) -> None:
        """Upsert a customer record.

        Args:
            customer: Entity to write; enum fields are stored by their
                ``.value`` and the MRR amount is bound as a float.
        """
        with get_connection(read_only=False) as conn:
            conn.execute(
                """
                INSERT OR REPLACE INTO raw.customers
                    (customer_id, industry, plan_tier, signup_date, mrr, churn_date)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                [
                    customer.customer_id,
                    customer.industry.value,
                    customer.plan_tier.value,
                    customer.signup_date,
                    # NOTE(review): float() may lose precision relative to the
                    # Decimal amount — confirm the column type tolerates this.
                    float(customer.mrr.amount),
                    customer.churn_date,
                ],
            )

    @staticmethod
    def _row_to_entity(row: tuple) -> Customer:  # type: ignore[type-arg]
        """Build a Customer from a row in this class's SELECT column order.

        Date columns are round-tripped through ``str()`` and
        ``date.fromisoformat``; a falsy churn_date becomes None.
        """
        customer_id, industry, plan_tier, signup_date, mrr, churn_date = row
        return Customer(
            customer_id=str(customer_id),
            industry=Industry(industry),
            plan_tier=PlanTier(plan_tier),
            signup_date=date.fromisoformat(str(signup_date)),
            # Go through str() so a float coming back from the driver does not
            # carry binary floating-point noise into the Decimal.
            mrr=MRR(amount=Decimal(str(mrr))),
            churn_date=date.fromisoformat(str(churn_date)) if churn_date else None,
        )
Functions
get_by_id
get_by_id(customer_id: str) -> Customer | None

Fetch a single customer by ID.

Source code in src/infrastructure/repositories/customer_repository.py
def get_by_id(self, customer_id: str) -> Customer | None:
    """Look up one customer by its identifier.

    Args:
        customer_id: Value matched against ``raw.customers.customer_id``.

    Returns:
        The matching Customer, or None when no row has that ID.
    """
    lookup_sql = """
            SELECT customer_id, industry, plan_tier, signup_date, mrr, churn_date
            FROM raw.customers
            WHERE customer_id = ?
            """
    with get_connection() as db:
        record = db.execute(lookup_sql, [customer_id]).fetchone()

    return None if record is None else self._row_to_entity(record)
get_all_active
get_all_active() -> Sequence[Customer]

Return all customers without a churn_date.

Source code in src/infrastructure/repositories/customer_repository.py
def get_all_active(self) -> Sequence[Customer]:
    """List every customer whose churn_date is NULL, highest MRR first."""
    active_sql = """
            SELECT customer_id, industry, plan_tier, signup_date, mrr, churn_date
            FROM raw.customers
            WHERE churn_date IS NULL
            ORDER BY mrr DESC
            """
    with get_connection() as db:
        records = db.execute(active_sql).fetchall()
    return list(map(self._row_to_entity, records))
get_sample
get_sample(n: int) -> Sequence[Customer]

Return a random sample of n customers using DuckDB reservoir sampling.

DuckDB's SAMPLE clause does not support parameterized queries — only literal integer constants are accepted. n is safe to interpolate because callers must clamp it to ≤ 100 before passing it in.

Source code in src/infrastructure/repositories/customer_repository.py
def get_sample(self, n: int) -> Sequence[Customer]:
    """Return a sample of n customers using DuckDB reservoir sampling.

    DuckDB's SAMPLE clause does not support parameterized queries — only
    literal integer constants are accepted — so n has to be interpolated
    into the SQL text. To keep that safe regardless of what a caller
    passes, n is coerced with ``int()`` and rejected when negative before
    it reaches the query. Callers are still expected to clamp n to ≤ 100.

    Args:
        n: Number of rows to sample.

    Returns:
        The sampled customers.

    Raises:
        ValueError: If n is negative or cannot be converted to an integer.
        TypeError: If n is of a type ``int()`` does not accept (e.g. None).
    """
    # Only a genuine integer may ever be spliced into the SQL string.
    sample_size = int(n)
    if sample_size < 0:
        raise ValueError(f"n must be a non-negative integer, got {n!r}")

    with get_connection() as conn:
        # NOTE(review): REPEATABLE(42) pins the sampling seed, so the sample
        # is presumably reproducible across calls — confirm against the
        # DuckDB SAMPLE documentation.
        rows = conn.execute(
            f"""
            SELECT customer_id, industry, plan_tier, signup_date, mrr, churn_date
            FROM raw.customers
            USING SAMPLE reservoir({sample_size} ROWS) REPEATABLE(42)
            """
        ).fetchall()
    return [self._row_to_entity(row) for row in rows]
save
save(customer: Customer) -> None

Upsert a customer record.

Source code in src/infrastructure/repositories/customer_repository.py
def save(self, customer: Customer) -> None:
    """Insert the customer, replacing any existing row for the same record.

    Args:
        customer: Entity to persist; enum fields are written by their
            ``.value`` and the MRR amount is bound as a float.
    """
    upsert_sql = """
            INSERT OR REPLACE INTO raw.customers
                (customer_id, industry, plan_tier, signup_date, mrr, churn_date)
            VALUES (?, ?, ?, ?, ?, ?)
            """
    with get_connection(read_only=False) as db:
        # Bind values follow the column order listed in the INSERT statement.
        bind_values = [
            customer.customer_id,
            customer.industry.value,
            customer.plan_tier.value,
            customer.signup_date,
            float(customer.mrr.amount),
            customer.churn_date,
        ]
        db.execute(upsert_sql, bind_values)

Functions

Usage Repository (DuckDB)

src.infrastructure.repositories.usage_repository

DuckDB implementation of UsageRepository.

Classes

DuckDBUsageRepository

Bases: UsageRepository

Reads UsageEvent entities from the DuckDB warehouse.

Source code in src/infrastructure/repositories/usage_repository.py
class DuckDBUsageRepository(UsageRepository):
    """Reads UsageEvent entities from the DuckDB warehouse."""

    def get_events_for_customer(
        self,
        customer_id: str,
        since: datetime | None = None,
    ) -> Sequence[UsageEvent]:
        """Return a customer's usage events, most recent first.

        Args:
            customer_id: Customer whose events are fetched.
            since: When given, only events with a timestamp at or after this
                moment are returned; otherwise no time filter is applied.

        Returns:
            UsageEvent entities ordered by timestamp descending.
        """
        bind_values: list[object] = [customer_id]
        since_clause = ""
        if since:
            # Only this fixed fragment is spliced into the SQL text; the
            # timestamp itself is passed as a bound parameter.
            since_clause = "AND timestamp >= ?"
            bind_values.append(since)

        with get_connection() as db:
            fetched = db.execute(
                f"""
                SELECT event_id, customer_id, timestamp, event_type, feature_adoption_score
                FROM raw.usage_events
                WHERE customer_id = ? {since_clause}
                ORDER BY timestamp DESC
                """,
                bind_values,
            ).fetchall()

        return [self._row_to_entity(record) for record in fetched]

    def get_event_count_last_n_days(self, customer_id: str, days: int) -> int:
        """Count a customer's usage events within the trailing ``days`` days.

        Args:
            customer_id: Customer whose events are counted.
            days: Size of the look-back window, measured back from the
                database's CURRENT_TIMESTAMP.

        Returns:
            The number of matching events, or 0 if the query yields no row.
        """
        with get_connection() as db:
            count_row = db.execute(
                """
                SELECT COUNT(*) FROM raw.usage_events
                WHERE customer_id = ?
                  AND timestamp >= CURRENT_TIMESTAMP - INTERVAL (?) DAY
                """,
                [customer_id, days],
            ).fetchone()

        if not count_row:
            return 0
        return int(count_row[0])

    @staticmethod
    def _row_to_entity(row: tuple) -> UsageEvent:  # type: ignore[type-arg]
        """Build a UsageEvent from a row in this class's SELECT column order."""
        event_id, customer_id, raw_timestamp, event_type, adoption_score = row

        # Use a datetime as-is; anything else is parsed from its string form.
        if isinstance(raw_timestamp, datetime):
            occurred_at = raw_timestamp
        else:
            occurred_at = datetime.fromisoformat(str(raw_timestamp))

        return UsageEvent(
            event_id=str(event_id),
            customer_id=str(customer_id),
            timestamp=occurred_at,
            event_type=EventType(event_type),
            feature_adoption_score=FeatureAdoptionScore(value=float(adoption_score)),
        )

Functions

Model Registry

src.infrastructure.ml.model_registry

Model registry – loads and caches DVC-tracked model artifacts.

Model files (.pkl) are versioned via DVC and stored in models/. This module is the only place in the codebase that touches pickle files.

Functions

load_model cached
load_model(name: str) -> Any

Load a pickled model artifact by name (cached after first load).

Parameters:

Name Type Description Default
name str

Model name without extension, e.g. "churn_model" or "risk_model".

required

Returns:

Type Description
Any

The deserialized model object.

Raises:

Type Description
FileNotFoundError

If the artifact does not exist. Run dvc pull first.

Source code in src/infrastructure/ml/model_registry.py
@lru_cache(maxsize=4)
def load_model(name: str) -> Any:  # noqa: ANN401 — generic model loader, return type depends on artifact
    """Load a pickled model artifact by name (cached after first load).

    The artifact is opened directly instead of being checked with
    ``exists()`` first, so a file that disappears between the check and the
    open cannot slip through. The ``model.loaded`` event is logged only once
    unpickling has succeeded.

    Args:
        name: Model name without extension, e.g. "churn_model" or "risk_model".

    Returns:
        The deserialized model object.

    Raises:
        FileNotFoundError: If the artifact does not exist. Run `dvc pull` first.
    """
    path = _MODELS_DIR / f"{name}.pkl"
    try:
        artifact = open(path, "rb")
    except FileNotFoundError as exc:
        raise FileNotFoundError(
            f"Model artifact not found at {path}. Run `dvc pull` to fetch versioned artifacts."
        ) from exc

    with artifact:
        # NOTE(security): unpickling can execute arbitrary code; this is only
        # acceptable as long as the artifacts come from a trusted source.
        model = pickle.load(artifact)  # noqa: S301

    logger.info("model.loaded", name=name, path=str(path))
    return model
get_model_metadata
get_model_metadata(name: str) -> dict[str, Any]

Return metadata (version, training date, metrics) for a model artifact.

Source code in src/infrastructure/ml/model_registry.py
def get_model_metadata(name: str) -> dict[str, Any]:
    """Return metadata (version, training date, metrics) for a model artifact.

    Args:
        name: Model name without extension; the metadata is read from
            ``<name>_metadata.json`` in the models directory.

    Returns:
        The parsed JSON content of the metadata file, or
        ``{"version": "unknown"}`` when no metadata file exists.
    """
    meta_path = _MODELS_DIR / f"{name}_metadata.json"
    if not meta_path.exists():
        return {"version": "unknown"}
    # JSON is UTF-8 by specification; without an explicit encoding open()
    # falls back to the platform locale, which can mis-decode non-ASCII text.
    with open(meta_path, encoding="utf-8") as f:
        return json.load(f)  # type: ignore[no-any-return]