
Prediction Domain

The prediction_domain bounded context contains the churn model, risk scoring, and SHAP explanation logic. It consumes entities from customer_domain and usage_domain but has no knowledge of infrastructure.

Entities

src.domain.prediction.entities

PredictionResult entity – output of the Prediction domain services.

Classes

ShapFeature dataclass

A single SHAP feature contribution to a prediction.

Provides model explainability for CS teams to understand why a customer was flagged, enabling targeted interventions.

Source code in src/domain/prediction/entities.py
@dataclass
class ShapFeature:
    """A single SHAP feature contribution to a prediction.

    Provides model explainability for CS teams to understand why a
    customer was flagged, enabling targeted interventions.
    """

    feature_name: str
    feature_value: float
    shap_impact: float  # Positive = increases churn risk
PredictionResult dataclass

The complete output of a churn + risk prediction for one customer.

Parameters:

Name Type Description Default
customer_id str

The customer this prediction belongs to.

required
churn_probability ChurnProbability

Calibrated P(churn in 90 days).

required
risk_score RiskScore

Composite compliance/usage risk score.

required
top_shap_features list[ShapFeature]

Top-N SHAP drivers (sorted by |shap_impact|).

list()
model_version str

Semantic version of the model artifact used.

'0.0.0'
predicted_at datetime

UTC timestamp of when the prediction was generated.

datetime.now(UTC)
Source code in src/domain/prediction/entities.py
@dataclass
class PredictionResult:
    """The complete output of a churn + risk prediction for one customer.

    Args:
        customer_id: The customer this prediction belongs to.
        churn_probability: Calibrated P(churn in 90 days).
        risk_score: Composite compliance/usage risk score.
        top_shap_features: Top-N SHAP drivers (sorted by |shap_impact|).
        model_version: Semantic version of the model artifact used.
        predicted_at: UTC timestamp of when the prediction was generated.
    """

    customer_id: str
    churn_probability: ChurnProbability
    risk_score: RiskScore
    top_shap_features: list[ShapFeature] = field(default_factory=list)
    model_version: str = "0.0.0"
    predicted_at: datetime = field(default_factory=lambda: datetime.now(UTC))

    @property
    def recommended_action(self) -> str:
        """Natural-language CS recommendation based on prediction outputs.

        This is a deterministic rule — LLM summaries build on top of this
        in the AI/LLM layer (Phase 5).
        """
        if self.churn_probability.value >= 0.75:
            return "CRITICAL – Escalate to senior CSM immediately. Schedule EBR within 7 days."
        if self.churn_probability.value >= 0.5:
            return "HIGH RISK – Trigger CS outreach within 48 hours. Review top SHAP drivers."
        if self.churn_probability.value >= 0.25:
            return "MEDIUM RISK – Add to CSM watch list. Schedule check-in call."
        return "LOW RISK – No immediate action required. Monitor monthly."
Attributes
recommended_action property
recommended_action: str

Natural-language CS recommendation based on prediction outputs.

This is a deterministic rule — LLM summaries build on top of this in the AI/LLM layer (Phase 5).
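As a minimal sketch of how the field defaults and the rule-based action behave, the snippet below uses simplified stand-ins. `Prob` and `Result` are hypothetical, trimmed copies of ChurnProbability and PredictionResult (no risk score, no validation), and `action_label` keeps only the leading label of each recommendation; none of these names exist in the project.

```python
from dataclasses import dataclass, field
from datetime import datetime, timezone


@dataclass(frozen=True)
class Prob:
    """Hypothetical stand-in for ChurnProbability (validation omitted)."""
    value: float


@dataclass
class Result:
    """Hypothetical, trimmed stand-in for PredictionResult."""
    customer_id: str
    churn_probability: Prob
    top_shap_features: list = field(default_factory=list)
    model_version: str = "0.0.0"
    predicted_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    @property
    def action_label(self) -> str:
        # Same thresholds as recommended_action; only the leading label is kept.
        p = self.churn_probability.value
        if p >= 0.75:
            return "CRITICAL"
        if p >= 0.5:
            return "HIGH RISK"
        if p >= 0.25:
            return "MEDIUM RISK"
        return "LOW RISK"


a = Result(customer_id="cust-001", churn_probability=Prob(0.5))
b = Result(customer_id="cust-002", churn_probability=Prob(0.49))

# default_factory gives each instance its own list, so appending to one
# result never leaks into another.
a.top_shap_features.append("placeholder")

print(a.action_label, b.action_label)
print(len(b.top_shap_features), a.predicted_at.tzinfo)
```

The thresholds are inclusive lower bounds, so a probability of exactly 0.5 already falls in the high-risk band while 0.49 stays in the medium band.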

Value Objects

src.domain.prediction.value_objects

Value objects for the Prediction domain.

Classes

ChurnProbability dataclass

P(churn in next 90 days) output from the churn model.

Calibrated probability in [0, 1]. The 0.5 threshold is the default operating point; business impact analysis should inform the actual threshold used for CS outreach triggers.

Source code in src/domain/prediction/value_objects.py
@dataclass(frozen=True)
class ChurnProbability:
    """P(churn in next 90 days) output from the churn model.

    Calibrated probability in [0, 1]. The 0.5 threshold is the default
    operating point; business impact analysis should inform the actual
    threshold used for CS outreach triggers.
    """

    value: float

    def __post_init__(self) -> None:
        if not (0.0 <= self.value <= 1.0):
            raise ValueError(f"ChurnProbability must be in [0, 1], got {self.value}")

    @property
    def risk_tier(self) -> RiskTier:
        if self.value >= 0.75:
            return RiskTier.CRITICAL
        if self.value >= 0.5:
            return RiskTier.HIGH
        if self.value >= 0.25:
            return RiskTier.MEDIUM
        return RiskTier.LOW

    @property
    def requires_immediate_action(self) -> bool:
        """True if CS outreach should be triggered within 48 hours."""
        return self.value >= 0.5
Attributes
requires_immediate_action property
requires_immediate_action: bool

True if CS outreach should be triggered within 48 hours.
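The following sketch illustrates the construction-time validation and the tier boundaries. It assumes a RiskTier enum with the four members referenced in the source; the enum's actual definition and values are not shown on this page, so the ones below are placeholders. ChurnProbability is reproduced in condensed form so the snippet runs on its own.

```python
from dataclasses import dataclass
from enum import Enum


class RiskTier(Enum):
    """Assumed definition; the source only shows the four member names."""
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"


@dataclass(frozen=True)
class ChurnProbability:
    """Condensed copy of the value object shown above."""
    value: float

    def __post_init__(self) -> None:
        if not (0.0 <= self.value <= 1.0):
            raise ValueError(f"ChurnProbability must be in [0, 1], got {self.value}")

    @property
    def risk_tier(self) -> RiskTier:
        if self.value >= 0.75:
            return RiskTier.CRITICAL
        if self.value >= 0.5:
            return RiskTier.HIGH
        if self.value >= 0.25:
            return RiskTier.MEDIUM
        return RiskTier.LOW


# Thresholds are inclusive lower bounds: exactly 0.25 is already MEDIUM.
tiers = {p: ChurnProbability(p).risk_tier for p in (0.0, 0.25, 0.5, 0.75, 1.0)}

# Out-of-range values are rejected at construction time.
try:
    ChurnProbability(1.2)
    rejected = False
except ValueError:
    rejected = True

# NaN also fails the range check, because every comparison with NaN is False.
try:
    ChurnProbability(float("nan"))
    nan_rejected = False
except ValueError:
    nan_rejected = True
```

Because the dataclass is frozen, a validated instance cannot later be mutated into an invalid state.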

RiskScore dataclass

Composite compliance + usage risk score in [0, 1].

Combines compliance_gap_score, vendor_risk_flags, and usage_decay_score. Distinct from churn probability: a customer can have a high risk score but a low churn probability if they are contractually locked in.

Source code in src/domain/prediction/value_objects.py
@dataclass(frozen=True)
class RiskScore:
    """Composite compliance + usage risk score in [0, 1].

    Combines compliance_gap_score, vendor_risk_flags, and usage_decay_score.
    Distinct from churn probability: a customer can have a high risk score
    but a low churn probability if they are contractually locked in.
    """

    value: float

    def __post_init__(self) -> None:
        if not (0.0 <= self.value <= 1.0):
            raise ValueError(f"RiskScore must be in [0, 1], got {self.value}")

    @property
    def tier(self) -> RiskTier:
        if self.value >= 0.75:
            return RiskTier.CRITICAL
        if self.value >= 0.5:
            return RiskTier.HIGH
        if self.value >= 0.25:
            return RiskTier.MEDIUM
        return RiskTier.LOW

Churn Model Service

src.domain.prediction.churn_model_service

ChurnModelService – domain service for churn probability prediction.

Domain services encapsulate operations that don't naturally belong to a single entity. The model artifact is injected as a dependency (no direct file I/O here).

Classes

ChurnFeatureVector

Bases: Protocol

Protocol for feature extraction – implemented in infrastructure layer.

Phase 4 update: the extractor queries the dbt mart directly (single DuckDB read), so events no longer need to be passed from the use case layer. This keeps the protocol minimal and moves feature logic into dbt.

Source code in src/domain/prediction/churn_model_service.py
class ChurnFeatureVector(Protocol):
    """Protocol for feature extraction – implemented in infrastructure layer.

    Phase 4 update: the extractor queries the dbt mart directly (single DuckDB
    read), so events no longer need to be passed from the use case layer.
    This keeps the protocol minimal and moves feature logic into dbt.
    """

    def extract(self, customer: Customer) -> dict[str, float | str]:
        """Extract the model's feature vector for a customer.

        Args:
            customer: Active Customer entity (used to look up mart row by ID).

        Returns:
            Flat dict of feature_name → value (numerics as float, categoricals
            as lowercase string for sklearn OrdinalEncoder compatibility).
            All feature engineering lives in mart_customer_churn_features.

        Raises:
            ValueError: If the customer is not found in the mart (e.g. churned
                        customers are excluded from the mart).
        """
        ...
Functions
extract
extract(customer: Customer) -> dict[str, float | str]

Extract the model's feature vector for a customer.

Parameters:

Name Type Description Default
customer Customer

Active Customer entity (used to look up mart row by ID).

required

Returns:

Type Description
dict[str, float | str]

Flat dict of feature_name → value (numerics as float, categoricals as lowercase string for sklearn OrdinalEncoder compatibility). All feature engineering lives in mart_customer_churn_features.

Raises:

Type Description
ValueError

If the customer is not found in the mart (e.g. churned customers are excluded from the mart).
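Because ChurnFeatureVector is a Protocol, any object with a matching `extract` method satisfies it without inheriting from it. The sketch below shows a test double under that assumption. `InMemoryFeatureExtractor`, the minimal `Customer` stand-in, and the feature names are all hypothetical and only serve the illustration.

```python
from dataclasses import dataclass
from typing import Protocol, Union

FeatureValue = Union[float, str]


@dataclass
class Customer:
    """Hypothetical minimal stand-in for the customer_domain entity."""
    customer_id: str


class ChurnFeatureVector(Protocol):
    def extract(self, customer: Customer) -> dict[str, FeatureValue]: ...


class InMemoryFeatureExtractor:
    """Hypothetical test double: looks rows up in a dict instead of the dbt mart."""

    def __init__(self, rows: dict) -> None:
        self._rows = rows

    def extract(self, customer: Customer) -> dict[str, FeatureValue]:
        try:
            row = self._rows[customer.customer_id]
        except KeyError:
            # Mirrors the documented contract: unknown customers raise ValueError.
            raise ValueError(f"customer {customer.customer_id} not found in mart") from None
        # Numerics as float, categoricals as lowercase strings.
        return {k: v.lower() if isinstance(v, str) else float(v) for k, v in row.items()}


# Feature names below are invented for illustration.
extractor: ChurnFeatureVector = InMemoryFeatureExtractor(
    {"cust-001": {"events_last_30d": 12, "plan_tier": "Enterprise"}}
)
features = extractor.extract(Customer("cust-001"))
print(features)
```

A double like this lets the domain service be exercised in unit tests without a DuckDB connection.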

ChurnModelPort

Bases: ABC

Abstract port for the underlying ML model.

Concrete implementations in src/infrastructure/ml/ load the trained XGBoost/survival model artifact.

Source code in src/domain/prediction/churn_model_service.py
class ChurnModelPort(ABC):
    """Abstract port for the underlying ML model.

    Concrete implementations in src/infrastructure/ml/
    load the trained XGBoost/survival model artifact.
    """

    @abstractmethod
    def predict_proba(self, features: dict[str, float | str]) -> float:
        """Return P(churn in 90 days) for the given feature vector."""
        ...

    @abstractmethod
    def explain(self, features: dict[str, float | str]) -> list[ShapFeature]:
        """Return SHAP feature contributions for explainability."""
        ...

    @property
    @abstractmethod
    def version(self) -> str:
        """Semantic version of the loaded model artifact."""
        ...
Attributes
version abstractmethod property
version: str

Semantic version of the loaded model artifact.

Functions
predict_proba abstractmethod
predict_proba(features: dict[str, float | str]) -> float

Return P(churn in 90 days) for the given feature vector.

explain abstractmethod
explain(features: dict[str, float | str]) -> list[ShapFeature]

Return SHAP feature contributions for explainability.
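To illustrate what a concrete adapter has to provide, here is a minimal stub that subclasses a condensed copy of the port. `ConstantChurnModel` is hypothetical: it returns a fixed probability and zero-impact explanations, whereas a real implementation in the infrastructure layer would load a trained artifact and compute SHAP values.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Union

FeatureValue = Union[float, str]


@dataclass
class ShapFeature:
    feature_name: str
    feature_value: float
    shap_impact: float  # Positive = increases churn risk


class ChurnModelPort(ABC):
    """Condensed copy of the abstract port shown above."""

    @abstractmethod
    def predict_proba(self, features: dict[str, FeatureValue]) -> float: ...

    @abstractmethod
    def explain(self, features: dict[str, FeatureValue]) -> list[ShapFeature]: ...

    @property
    @abstractmethod
    def version(self) -> str: ...


class ConstantChurnModel(ChurnModelPort):
    """Hypothetical stub: fixed probability, zero-impact explanations."""

    def __init__(self, probability: float) -> None:
        self._probability = probability

    def predict_proba(self, features: dict[str, FeatureValue]) -> float:
        return self._probability

    def explain(self, features: dict[str, FeatureValue]) -> list[ShapFeature]:
        # Only numeric features are reported here; a real adapter would
        # compute SHAP values for the encoded feature vector.
        return [
            ShapFeature(name, float(value), 0.0)
            for name, value in features.items()
            if not isinstance(value, str)
        ]

    @property
    def version(self) -> str:
        return "0.0.0-stub"


model = ConstantChurnModel(0.3)
explanations = model.explain({"events_last_30d": 12.0, "plan_tier": "enterprise"})

# The abstract port itself cannot be instantiated.
try:
    ChurnModelPort()
    port_is_abstract = False
except TypeError:
    port_is_abstract = True
```

Forgetting any of the three abstract members in a subclass produces the same TypeError at instantiation, which catches incomplete adapters early.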

ChurnModelService

Orchestrates feature extraction → model inference → result construction.

Parameters:

Name Type Description Default
model ChurnModelPort

Concrete ML model (injected from infrastructure layer).

required
feature_extractor ChurnFeatureVector

Queries the dbt mart for the customer's feature vector.

required
Source code in src/domain/prediction/churn_model_service.py
class ChurnModelService:
    """Orchestrates feature extraction → model inference → result construction.

    Args:
        model: Concrete ML model (injected from infrastructure layer).
        feature_extractor: Queries the dbt mart for the customer's feature vector.
    """

    def __init__(self, model: ChurnModelPort, feature_extractor: ChurnFeatureVector) -> None:
        self._model = model
        self._feature_extractor = feature_extractor

    def predict(
        self,
        customer: Customer,
        risk_score: RiskScore,
    ) -> PredictionResult:
        """Generate a full PredictionResult for a customer.

        Business Context: Feature extraction, model inference, and SHAP
        computation are all delegated to injected dependencies. This service
        only owns the assembly logic, keeping it testable in isolation.

        Args:
            customer: Active Customer entity.
            risk_score: Pre-computed composite risk score (from RiskModelService).

        Returns:
            PredictionResult with calibrated churn probability, SHAP explanations,
            and a deterministic recommended CS action.
        """
        features = self._feature_extractor.extract(customer)
        churn_prob = self._model.predict_proba(features)
        shap_features = self._model.explain(features)

        return PredictionResult(
            customer_id=customer.customer_id,
            churn_probability=ChurnProbability(value=churn_prob),
            risk_score=risk_score,
            top_shap_features=sorted(shap_features, key=lambda f: abs(f.shap_impact), reverse=True)[:5],
            model_version=self._model.version,
        )
Functions
predict
predict(customer: Customer, risk_score: RiskScore) -> PredictionResult

Generate a full PredictionResult for a customer.

Business Context: Feature extraction, model inference, and SHAP computation are all delegated to injected dependencies. This service only owns the assembly logic, keeping it testable in isolation.

Parameters:

Name Type Description Default
customer Customer

Active Customer entity.

required
risk_score RiskScore

Pre-computed composite risk score (from RiskModelService).

required

Returns:

Type Description
PredictionResult

PredictionResult with calibrated churn probability, SHAP explanations, and a deterministic recommended CS action.
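The only non-trivial assembly step in `predict` is the ranking of SHAP drivers. The sketch below isolates that expression. `FakeModel`, `top_drivers`, and the impact values are made up for the illustration; only the sort expression is taken from the source.

```python
from dataclasses import dataclass


@dataclass
class ShapFeature:
    feature_name: str
    feature_value: float
    shap_impact: float  # Positive = increases churn risk


class FakeModel:
    """Hypothetical stand-in for a ChurnModelPort implementation."""

    version = "1.2.3"

    def predict_proba(self, features):
        return 0.62

    def explain(self, features):
        # Invented impacts, deliberately unsorted and mixed in sign.
        impacts = [0.05, -0.40, 0.10, -0.02, 0.30, 0.01, -0.20]
        return [ShapFeature(f"f{i}", 0.0, s) for i, s in enumerate(impacts)]


def top_drivers(shap_features, n=5):
    # Same expression as in ChurnModelService.predict: rank by absolute
    # impact, so strong negative (risk-reducing) drivers are kept too.
    return sorted(shap_features, key=lambda f: abs(f.shap_impact), reverse=True)[:n]


model = FakeModel()
features = {}  # a real call would pass the extractor's output
top = top_drivers(model.explain(features))
print([(f.feature_name, f.shap_impact) for f in top])
```

Ranking by absolute value means a CS reader sees the strongest drivers in either direction; the sign of `shap_impact` then says whether each one pushes risk up or down.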

Risk Model Service

src.domain.prediction.risk_model_service

RiskModelService – domain service for compliance + usage risk scoring.

Classes

RiskSignals dataclass

Raw risk inputs from the risk_signals table.

Parameters:

Name Type Description Default
compliance_gap_score float

0–1 score of open compliance gaps.

required
vendor_risk_flags int

Count of third-party vendor risk alerts.

required
usage_decay_score float

0–1 score of recent usage decline (computed from events).

required
Source code in src/domain/prediction/risk_model_service.py
@dataclass(frozen=True)
class RiskSignals:
    """Raw risk inputs from the risk_signals table.

    Args:
        compliance_gap_score: 0–1 score of open compliance gaps.
        vendor_risk_flags: Count of third-party vendor risk alerts.
        usage_decay_score: 0–1 score of recent usage decline (computed from events).
    """

    compliance_gap_score: float
    vendor_risk_flags: int
    usage_decay_score: float
RiskModelService

Computes a composite risk score from compliance and usage signals.

Weights are calibrated to business impact:

- Usage decay is the strongest leading indicator of near-term churn
- Compliance gaps drive risk but not always churn (contractual stickiness)
- Vendor risk flags have lower weight but non-zero contribution

These weights should be revisited quarterly using SHAP analysis on the full churn model to ensure they remain calibrated to observed outcomes.

Source code in src/domain/prediction/risk_model_service.py
class RiskModelService:
    """Computes a composite risk score from compliance and usage signals.

    Weights are calibrated to business impact:
    - Usage decay is the strongest leading indicator of near-term churn
    - Compliance gaps drive risk but not always churn (contractual stickiness)
    - Vendor risk flags have lower weight but non-zero contribution

    These weights should be revisited quarterly using SHAP analysis on
    the full churn model to ensure they remain calibrated to observed outcomes.
    """

    USAGE_WEIGHT: float = 0.50
    COMPLIANCE_WEIGHT: float = 0.35
    VENDOR_WEIGHT: float = 0.15
    VENDOR_FLAG_NORMALISER: float = 5.0  # treat 5+ flags as max risk

    def compute(self, signals: RiskSignals) -> RiskScore:
        """Compute a composite RiskScore from raw signals.

        Args:
            signals: The three risk signal components.

        Returns:
            RiskScore value object in [0, 1].
        """
        vendor_normalised = min(signals.vendor_risk_flags / self.VENDOR_FLAG_NORMALISER, 1.0)
        composite = (
            self.USAGE_WEIGHT * signals.usage_decay_score
            + self.COMPLIANCE_WEIGHT * signals.compliance_gap_score
            + self.VENDOR_WEIGHT * vendor_normalised
        )
        return RiskScore(value=round(composite, 4))
Functions
compute
compute(signals: RiskSignals) -> RiskScore

Compute a composite RiskScore from raw signals.

Parameters:

Name Type Description Default
signals RiskSignals

The three risk signal components.

required

Returns:

Type Description
RiskScore

RiskScore value object in [0, 1].
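A worked example of the weighted sum may help when reviewing the weights. `composite_risk` below is a hypothetical standalone restatement of `compute` that returns a bare float instead of a RiskScore; the constants are copied from the class, and the input values are invented.

```python
USAGE_WEIGHT = 0.50
COMPLIANCE_WEIGHT = 0.35
VENDOR_WEIGHT = 0.15
VENDOR_FLAG_NORMALISER = 5.0  # treat 5+ flags as max risk


def composite_risk(
    usage_decay_score: float,
    compliance_gap_score: float,
    vendor_risk_flags: int,
) -> float:
    """Standalone restatement of RiskModelService.compute (returns a float)."""
    vendor_normalised = min(vendor_risk_flags / VENDOR_FLAG_NORMALISER, 1.0)
    composite = (
        USAGE_WEIGHT * usage_decay_score
        + COMPLIANCE_WEIGHT * compliance_gap_score
        + VENDOR_WEIGHT * vendor_normalised
    )
    return round(composite, 4)


# 0.50 * 0.8 + 0.35 * 0.4 + 0.15 * (2 / 5) = 0.40 + 0.14 + 0.06
moderate = composite_risk(0.8, 0.4, 2)

# Vendor flags saturate at 5: 5 and 50 flags contribute the same 0.15.
capped_a = composite_risk(0.0, 0.0, 5)
capped_b = composite_risk(0.0, 0.0, 50)

# The weights sum to 1.0, so all-maximum inputs reach the top of the range.
worst = composite_risk(1.0, 1.0, 10)

print(moderate, capped_a, capped_b, worst)
```

Note that RiskSignals, as shown, does not range-check its fields. If a usage_decay_score or compliance_gap_score above 1 reached `compute`, the composite could exceed 1 and the RiskScore constructor would raise ValueError.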
