Source code for kkf.solution
"""Solution data structure for Koopman Kalman Filter results.
Provides the KoopmanKalmanFilterSolution dataclass for storing and accessing
filter outputs.
"""
from dataclasses import dataclass
from typing import Dict
import numpy as np
from numpy.typing import NDArray
[docs]
@dataclass
class KoopmanKalmanFilterSolution:
"""
A class to store the solution of a Koopman-Kalman Filter iteration.
This class maintains the state estimates, covariance matrices, and filter gains
for both the prior (minus) and posterior (plus) estimates in both state and
feature spaces.
Attributes
----------
x_plus : np.ndarray
Posterior state estimate after measurement update.
x_minus : np.ndarray
Prior state estimate from prediction step.
Pz_plus : np.ndarray
Posterior covariance matrix in feature space after measurement update.
Pz_minus : np.ndarray
Prior covariance matrix in feature space from prediction step.
Px_plus : np.ndarray
Posterior covariance matrix in state space after measurement update.
Px_minus : np.ndarray
Prior covariance matrix in state space from prediction step.
S : np.ndarray
Innovation (residual) covariance matrix.
K : np.ndarray
Kalman gain matrix.
Notes
-----
The class uses the common Kalman filter notation where:
- (-) denotes prior estimates before measurement update
- (+) denotes posterior estimates after measurement update
- Pz refers to covariances in the feature/transformed space
- Px refers to covariances in the original state space
Examples
--------
>>> solution = KoopmanKalmanFilterSolution(
... x_plus=np.array([1.0, 2.0]),
... x_minus=np.array([0.9, 1.9]),
... Pz_plus=np.eye(2),
... Pz_minus=np.eye(2) * 1.1,
... Px_plus=np.eye(2) * 0.9,
... Px_minus=np.eye(2),
... S=np.eye(2) * 0.5,
... K=np.array([[0.1, 0], [0, 0.1]])
... )
"""
x_plus: NDArray[np.float64]
x_minus: NDArray[np.float64]
Pz_plus: NDArray[np.float64]
Pz_minus: NDArray[np.float64]
Px_plus: NDArray[np.float64]
Px_minus: NDArray[np.float64]
S: NDArray[np.float64]
K: NDArray[np.float64]
def __post_init__(self) -> None:
"""Validate the dimensions of the input arrays."""
# Ensure all inputs are numpy arrays
for attr in ["x_plus", "x_minus", "Pz_plus", "Pz_minus", "Px_plus", "Px_minus", "S", "K"]:
value = getattr(self, attr)
if not isinstance(value, np.ndarray):
setattr(self, attr, np.array(value))
[docs]
def get_state_dimension(self) -> int:
"""
Get the dimension of the state vector.
Returns
-------
int
The dimension of the state vector.
"""
return self.x_plus.shape[-1]
[docs]
def get_feature_dimension(self) -> int:
"""
Get the dimension of the feature space.
Returns
-------
int
The dimension of the feature space.
"""
return self.Pz_plus.shape[-1]
[docs]
def get_estimation_error(self) -> NDArray[np.float64]:
"""
Calculate the difference between prior and posterior estimates.
Returns
-------
np.ndarray
The difference between posterior and prior state estimates.
"""
return self.x_plus - self.x_minus
[docs]
def get_trace_reduction(self) -> float:
"""
Calculate the reduction in uncertainty as measured by trace of covariance.
Returns
-------
float
The relative reduction in trace of the state covariance matrix.
"""
trace_minus = np.trace(self.Px_minus)
trace_plus = np.trace(self.Px_plus)
return (trace_minus - trace_plus) / trace_minus if trace_minus != 0 else 0.0
[docs]
def to_dict(self) -> Dict[str, NDArray[np.float64]]:
"""
Convert the solution to a dictionary format.
Returns
-------
dict
Dictionary containing all solution components.
"""
return {
"x_plus": self.x_plus,
"x_minus": self.x_minus,
"Pz_plus": self.Pz_plus,
"Pz_minus": self.Pz_minus,
"Px_plus": self.Px_plus,
"Px_minus": self.Px_minus,
"S": self.S,
"K": self.K,
}