Classifying Ships with Gaussian Process Mixtures

machine_learning
Published

November 19, 2025

using DataFrames, PythonCall, CSV, Dates, DataFramesMeta, RecursiveArrayTools, StatsBase, CairoMakie, Distributions, SpecialFunctions, PDMats, LinearAlgebra, SizeCheck, KernelFunctions, AbstractGPs, LogExpFunctions, Random
CairoMakie.enable_only_mime!("png")

I recently came across a dataset of container ship movement between Tallinn and Helsinki on Kaggle. In this notebook, we’ll try to classify whether a given ship’s trajectory seems similar to those of the container ships, or whether we’re looking at something else (perhaps a pirate).

I’ll start by loading the ship tracking data, taking z-scores of the latitude and longitude, and normalizing timing information for each trajectory. We can identify rows belonging to the same ship’s trajectory by the ship’s International Maritime Organization code (IMO) and its actual arrival time (ATA). For simplicity, we’ll just consider 2k samples of trajectories arriving in Helsinki (arrival port FIHEL) for the time being.

function load_ship_data()
    kagglehub = pyimport("kagglehub")
    ships_path = pyconvert(String, kagglehub.dataset_download("bobaaayoung/container-ship-data-collection"))
    raw_df = dropmissing(
        CSV.read(joinpath(ships_path, "tracking_db.csv"), DataFrame,
            dateformat="mm/dd/yyyy HH:MM",
            types=Dict(:updated => DateTime, :ata => DateTime), silencewarnings=true))

    df = @chain raw_df begin
        @subset(:long .> 24)
        @subset(:arrPort .== "FIHEL")
        @select(:updated, :ata, :long, :lat, :imo)
        @transform(:t = Minute.(:ata .- :updated), :long = zscore(:long), :lat = zscore(:lat))
        @groupby :imo :ata
        @transform(:nt = (:t .- minimum(:t)) ./ (maximum(:t) - minimum(:t)))
        @subset((:nt .> 0) .& (:nt .< 1))
    end

    df, groupby(df, [:imo, :ata])
end
load_ship_data (generic function with 1 method)
df, groups = load_ship_data()
(94944×7 DataFrame
   Row │ updated              ata                  long       lat       imo    ⋯
       │ DateTime             DateTime             Float64    Float64   Int64  ⋯
───────┼────────────────────────────────────────────────────────────────────────
     1 │ 2018-05-04T19:26:00  2018-04-05T21:26:00  -0.545055  -1.52547  936472 ⋯
     2 │ 2018-05-04T19:29:00  2018-04-05T21:24:00  -0.615397  -1.47017  936472
     3 │ 2018-05-04T19:33:00  2018-04-05T21:28:00  -0.806885  -1.34498  936472
     4 │ 2018-05-04T19:36:00  2018-04-05T21:26:00  -0.918912  -1.28835  936472
     5 │ 2018-05-04T19:37:00  2018-04-05T21:31:00  -0.974925  -1.25871  936472 ⋯
     6 │ 2018-05-04T19:39:00  2018-04-05T21:30:00  -1.09086   -1.20032  936472
     7 │ 2018-05-04T19:39:00  2018-04-05T21:26:00  -1.10128   -1.19501  936472
     8 │ 2018-05-04T19:41:00  2018-04-05T21:26:00  -1.22112   -1.13661  936472
   ⋮   │          ⋮                    ⋮               ⋮         ⋮         ⋮   ⋱
 94938 │ 2019-12-03T20:06:00  2019-03-12T20:35:00   1.06891    1.07842  936472 ⋯
 94939 │ 2019-12-03T20:07:00  2019-03-12T20:29:00   1.03374    1.1054   936472
 94940 │ 2019-12-03T20:12:00  2019-03-12T20:28:00   1.05458    1.24564  936472
 94941 │ 2019-12-03T20:14:00  2019-03-12T20:26:00   1.09106    1.29386  936472
 94942 │ 2019-12-03T20:16:00  2019-03-12T20:26:00   1.14447    1.33854  936472 ⋯
 94943 │ 2019-12-03T20:19:00  2019-03-12T20:24:00   1.1627     1.3965   936472
 94944 │ 2019-12-03T20:23:00  2019-03-12T20:22:00   1.15358    1.48011  936472
                                                3 columns and 94929 rows omitted,
GroupedDataFrame with 18432 groups based on keys: imo, ata
First Group (14 rows): imo = 9364722, ata = 2018-04-05T21:26:00
 Row │ updated              ata                  long       lat        imo      ⋯
     │ DateTime             DateTime             Float64    Float64    Int64    ⋯
─────┼──────────────────────────────────────────────────────────────────────────
   1 │ 2018-05-04T19:26:00  2018-04-05T21:26:00  -0.545055  -1.52547   9364722 ⋯
   2 │ 2018-05-04T19:36:00  2018-04-05T21:26:00  -0.918912  -1.28835   9364722
   3 │ 2018-05-04T19:39:00  2018-04-05T21:26:00  -1.10128   -1.19501   9364722
   4 │ 2018-05-04T19:41:00  2018-04-05T21:26:00  -1.22112   -1.13661   9364722
  ⋮  │          ⋮                    ⋮               ⋮          ⋮         ⋮    ⋱
  11 │ 2018-05-04T20:40:00  2018-04-05T21:26:00   0.945162   0.413511  9364722 ⋯
  12 │ 2018-05-04T20:47:00  2018-04-05T21:26:00   0.955583   0.609489  9364722
  13 │ 2018-05-04T20:49:00  2018-04-05T21:26:00   0.984241   0.669211  9364722
  14 │ 2018-05-04T20:55:00  2018-04-05T21:26:00   1.10539    0.856341  9364722
                                                    2 columns and 6 rows omitted
⋮
Last Group (1 row): imo = 9364722, ata = 2019-03-12T20:22:00
 Row │ updated              ata                  long     lat      imo      t   ⋯
     │ DateTime             DateTime             Float64  Float64  Int64    Mi ⋯
─────┼──────────────────────────────────────────────────────────────────────────
   1 │ 2019-12-03T20:23:00  2019-03-12T20:22:00  1.15358  1.48011  9364722  -3 ⋯
                                                               2 columns omitted)
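To make the two normalizations concrete, here is a minimal sketch on a made-up voyage, using only the standard library. The report times, coordinates, and the helper name `zscore_sketch` are all hypothetical; the sketch uses the sample standard deviation for the z-score.

```julia
using Dates, Statistics

# Hypothetical toy voyage: five position reports and one arrival time.
updated = DateTime(2018, 5, 4, 19, 26) .+ Minute.([0, 10, 13, 15, 89])
ata = DateTime(2018, 5, 4, 21, 26)

# Minutes remaining until arrival at each report (DateTime differences are in milliseconds).
t = Dates.value.(ata .- updated) ./ 60_000

# Rescale to [0, 1] within the voyage, mirroring the `:nt` column.
# With t = ata - updated, nt = 1 is the earliest report and nt = 0 the latest.
nt = (t .- minimum(t)) ./ (maximum(t) - minimum(t))

# z-score: subtract the mean and divide by the sample standard deviation.
zscore_sketch(x) = (x .- mean(x)) ./ std(x)
long_z = zscore_sketch([24.9, 24.95, 25.0, 25.05, 25.1])
```

Note that the post z-scores longitude and latitude over the whole dataset, while the time normalization is done separately for each trajectory.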

Below, I’ve plotted the trajectories with color marking the passage of time.

scatter(df[!, :long], df[!, :lat], color=df[!, :nt], markersize=2, alpha=0.4)

The Model

It seems like there’s more than one standard way of moving between these ports. I make out three different routes; trajectories seem to be scattered around a central curve for each route.

We can model this behavior with a mixture of Gaussian processes. For each of the three different routes above (\(i=1,2,3\)), we’ll assume a function \(f_i\) mapping time to latitude and longitude values was sampled from a Gaussian process prior. Container ship trajectories will always be close to one of these routes; we just don’t know which one. We’ll also assume that other (non-container) ships follow trajectories sampled independently from the same Gaussian process prior. To tell if a given trajectory seems to be that of a container ship, we just need to check whether it more closely resembles the posterior mixture of container routes or the prior over all possible ship trajectories.
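As a rough illustration of what "sampled from a Gaussian process prior" means here, the sketch below draws one two-coordinate route using only the standard library. The hand-written `matern52` function, the lengthscale of 0.1, the time grid, and the jitter value are my assumptions for the example, not something fixed by the model description.

```julia
using LinearAlgebra, Random

# Matérn 5/2 kernel with lengthscale ℓ, written out to stay dependency-free.
function matern52(x, y; ℓ=0.1)
    r = sqrt(5) * abs(x - y) / ℓ
    (1 + r + r^2 / 3) * exp(-r)
end

rng = MersenneTwister(1)
ts = range(0, 1; length=50)

# Prior covariance at the chosen times, with a small jitter for numerical stability.
K = [matern52(a, b) for a in ts, b in ts] + 1e-8 * I

# One route: an independent prior draw for each of the two coordinates.
chol = cholesky(Symmetric(K))
route = chol.L * randn(rng, length(ts), 2)   # 50×2 matrix
```

Each column of `route` is a smooth random curve over normalized time; a container route in the model is one such pair of curves.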

Specifically, let \(y\) refer to the observed trajectories and \(y'\) refer to the new trajectory we’re trying to classify. Say \(B=1\) if \(y'\) is a container ship and \(B=0\) otherwise. We want to learn the posterior odds that \(B=1\) given \(y\) and \(y'\), which is just \(\frac{P(y' | y, B=1)P(B=1)}{P(y' | y, B=0)P(B=0)}\). When \(B=1\), we can get the likelihood of \(y'\) by integrating over samples \(f\) from the posterior GP mixture: \(\int P(y' | f)P(f | y) \, df\). When \(B=0\), we can get the likelihood by integrating over samples \(g\) from the prior GP: \(\int P(y' | g)P(g) \, dg\).

We don’t know a priori what fraction of the ships in the region are container ships, but to be conservative, we’ll give equal prior probability to \(B=0\) and \(B=1\).
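With equal prior probabilities the prior odds cancel, so the posterior probability is a logistic function of the log-likelihood ratio. Writing \(\text{logistic}(x) = 1/(1+e^{-x})\), and taking \(P(y' \mid y, B=1)\) to be the posterior-mixture likelihood and \(P(y' \mid y, B=0)\) the prior likelihood:

\[
P(B=1 \mid y, y') = \frac{P(y' \mid y, B=1)}{P(y' \mid y, B=1) + P(y' \mid y, B=0)} = \text{logistic}\big(\log P(y' \mid y, B=1) - \log P(y' \mid y, B=0)\big)
\]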

Modeling Sub-Trajectories

Ideally, we’d like to identify out-of-distribution ships without having to observe their full port-to-port trajectories. To do this, we can marginalize the likelihood of a trajectory snippet over possible offsets in time.

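function marginal_logprob(dist, val, t, σ2)
    # Candidate start offsets; `t` holds times relative to the start of the snippet.
    starts = LinRange(0, 1 - t[end], 200)
    # Sum over offsets in log space. The uniform 1/200 weight is omitted
    # because it cancels in the likelihood ratio taken in `ok_prob`.
    logsumexp([logpdf(dist(s .+ t, σ2), val) for s in starts])
end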
marginal_logprob (generic function with 1 method)
"""
Find the posterior probability that a ship with longitude/latitude trajectory pairs (`y1`, `y2`) at times `t` is a container ship (with trajectories coming from posterior GPs `y1p` and `y2p`) rather than some other kind of ship (with trajectories coming from prior `gp`).
"""
function ok_prob(t, gp, y1p, y2p, y1, y2)
    # `marginal_logprob` needs the noise variance; σ is the global noise scale defined below.
    logistic(marginal_logprob(y1p, y1, t, σ^2) + marginal_logprob(y2p, y2, t, σ^2) -
             marginal_logprob(gp, y1, t, σ^2) - marginal_logprob(gp, y2, t, σ^2))
end
ok_prob (generic function with 1 method)
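To see the offset marginalization in isolation, here is a small stand-alone sketch. It swaps the GP for a toy model (a known mean curve plus independent Gaussian noise) so that it runs without any packages. The names `route`, `loglik`, `logsumexp_sketch`, and `marginal_loglik_sketch`, and the noise level of 0.05, are all made up for the example.

```julia
using Random

# Numerically stable log-sum-exp, written out to avoid extra dependencies.
function logsumexp_sketch(xs)
    m = maximum(xs)
    m + log(sum(exp.(xs .- m)))
end

# Toy stand-in for a trajectory model: known mean curve plus iid Gaussian noise.
route(t) = sin(2π * t)
loglik(y, t; σ=0.05) = sum(@. -0.5 * ((y - route(t)) / σ)^2 - log(σ) - 0.5 * log(2π))

# Log-likelihood of a snippet `y` seen at relative times `t` (starting at 0),
# summed over a grid of candidate start offsets.
function marginal_loglik_sketch(y, t; nstarts=200)
    starts = range(0, 1 - t[end]; length=nstarts)
    logsumexp_sketch([loglik(y, s .+ t) for s in starts])
end

rng = MersenneTwister(0)
t = range(0, 0.2; length=20)                        # snippet covers 20% of a voyage
y_on = route.(0.5 .+ t) .+ 0.05 .* randn(rng, 20)   # follows the curve from offset 0.5
y_off = fill(3.0, 20)                               # nowhere near the curve

on_score = marginal_loglik_sketch(y_on, t)
off_score = marginal_loglik_sketch(y_off, t)
```

The snippet that follows the curve scores far higher than the one that does not, even though the scoring function is never told the true offset.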

Approximating Posterior GP Mixtures

We can’t compute the posterior of a mixture of Gaussian processes analytically. But we can approximate it with mean field variational inference.

Let \(z_i\) be a one-hot vector giving the route chosen in trajectory \(i\). I assume \(z_i \sim \text{Categorical}(\pi)\), and \(\pi \sim \text{Dirichlet}(\alpha_0)\). The variational posterior for \(\pi\) will be Dirichlet with parameters \(\alpha\).

For the likelihood, I’ll introduce inducing inputs \(c\) and outputs \(u_k\) for each route and assume that \(u_k = f_k(c)\) and \(p(y_i | u_k, z_{ik}=1) = \mathcal{N}(K_{yu} K_{uu}^{-1}u_k, Q_{yy})\) where \(Q_{yy} = \text{diag}(K_{yy} - K_{yu}K_{uu}^{-1}K_{uy}) + \sigma^2I\) (the Fully Independent Training Conditional assumption). Let \(A = K_{yu}K_{uu}^{-1}\) and \(Q_{yy} = \Lambda^{-1}\). We can pre-calculate the kernel matrices and store them in a separate KernelMats struct for each trajectory. If we were more concerned with performance, we would avoid finding \(K_{uu}^{-1}\) explicitly, but this is good enough for a blog post.

σ = 0.03
0.03
struct KernelMats
    Λ::PDiagMat{Float64,Vector{Float64}}
    A::Matrix{Float64}
end
@sizecheck function kernel_mats(kern, x_N, c_U)
    K_UU_inv = inv(PDMat(kern.(c_U', c_U)))
    k = kern(1, 1) # Assuming stationarity
    K_N = [kern.(c_U, reshape(x, (1, :))) for x in x_N]
    c = KernelMats[
        let
            K_UT = K_N[n]
            KernelMats(
                inv(PDiagMat(k .- diag(K_UT' * (K_UU_inv * K_UT)) .+ σ^2)),
                K_UT' * K_UU_inv)
        end for n in 1:N
    ]
    (K_UU_inv, c)
end
kernel_mats (generic function with 1 method)
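For reference, one way to get \(A\) and \(\Lambda\) without forming \(K_{uu}^{-1}\) is to factor \(K_{uu}\) once and solve against it. The sketch below is a dependency-free illustration, not a drop-in replacement for kernel_mats: the hand-written `matern52` kernel, the helper name `fitc_terms`, and the grid sizes are assumptions for the example.

```julia
using LinearAlgebra

# Hand-written Matérn 5/2 kernel (lengthscale ℓ) so the sketch has no dependencies.
function matern52(x, y; ℓ=0.1)
    r = sqrt(5) * abs(x - y) / ℓ
    (1 + r + r^2 / 3) * exp(-r)
end

# Returns A = K_yu K_uu^{-1} and the diagonal of Λ = Q_yy^{-1} for one trajectory
# observed at times `x`, given inducing inputs `c` and noise variance `σ2`.
function fitc_terms(x, c, σ2)
    K_uu = Symmetric([matern52(a, b) for a in c, b in c])
    K_uy = [matern52(a, b) for a in c, b in x]
    F = cholesky(K_uu)
    # Solve K_uu X = K_uy rather than inverting K_uu; A is the transpose of X.
    A = Matrix((F \ K_uy)')
    # diag(K_yu K_uu^{-1} K_uy) is the row-wise dot product of A and K_yu.
    q_diag = vec(sum(A .* K_uy', dims=2))
    λ = 1 ./ (matern52(0.0, 0.0) .- q_diag .+ σ2)
    A, λ
end

x = range(0, 1; length=40)
c = range(0, 1; length=10)
A, λ = fitc_terms(x, c, 0.03^2)
```

This also avoids building the full \(T \times T\) matrix \(K_{yu}K_{uu}^{-1}K_{uy}\) when only its diagonal is needed.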

The variational posterior for inducing points \(u_k\) will be normal with mean \(m_k\) and covariance matrix \(S_k\).

struct VComp
    S_inv::PDMat{Float64,Matrix{Float64}}
    m1::Vector{Float64}
    m2::Vector{Float64}
end

Now to figure out the variational updates. We’ll start by deriving the component update for \(q(z_i)\). From the mean field assumption, we know the ELBO is maximized when \(\log q(z_{ik}) = E_q \log p(y_i | u_k) + E_q \log p(z_{ik} | \pi) + \text{const}\). As the approximate posterior over \(\pi\) is Dirichlet, \(E_q \log p(z_{ik} | \pi) = E_q \log \pi_k = \psi(\alpha_k) - \psi(\sum_j \alpha_j)\). It remains to find

\(E_q \log p(y_i | u_k) = -\frac{1}{2} E_q (y-Au_k)^T\Lambda (y-Au_k) + \frac{1}{2} \log |\Lambda| + \text{const}\)

Expanding the quadratic term lets us compute the expectation:

\(E_q (y-Au_k)^T\Lambda (y-Au_k) = -2 y^T\Lambda A m_k + y^T\Lambda y + \text{Tr}(E[u_ku_k^T]A^T \Lambda A)\)
\(= -2 y^T\Lambda A m_k + y^T\Lambda y + \text{Tr}((m_km_k^T + S_k)A^T \Lambda A)\)
\(= -2 y^T\Lambda A m_k + y^T\Lambda y + \text{Tr}(S_kA^T \Lambda A) + m_k^TA^T\Lambda A m_k\)
\(= (y - Am_k)^T\Lambda (y - A m_k) + \text{Tr}(S_kA^T \Lambda A)\)

We can eliminate terms like \(\psi(\sum_j \alpha _j)\) and \(\log | \Lambda |\) which are the same for all components. The result is an expression for \(q(z_i = k)\), which I’ll write as \(r_{ik}\).
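Collecting the surviving terms, the update for a single output coordinate can be written as:

\[
r_{ik} \propto \exp\Big( \psi(\alpha_k) - \tfrac{1}{2}(y_i - A m_k)^T \Lambda (y_i - A m_k) - \tfrac{1}{2}\text{Tr}(S_k A^T \Lambda A) \Big), \qquad \sum_k r_{ik} = 1
\]

Here \(A\) and \(\Lambda\) are the matrices for trajectory \(i\). With two coordinates (longitude and latitude) sharing the same \(S_k\), the quadratic terms for the two coordinates add and the trace term is counted twice, which is why the code for responsibilities subtracts the full trace rather than half of it.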

"Calculate r_ik for seeing `(y1, y2)` for each component in `o`"
function responsibilities(y1, y2, c::Vector{KernelMats}, o, alpha)
    # Components that were dropped (`nothing`) get zero responsibility via -Inf.
    Vector{Float64}[
        softmax(Float64[isnothing(og) ? -Inf64 : let
            μ1 = cn.A * og.m1
            μ2 = cn.A * og.m2
            V = og.S_inv \ (cn.A' * cn.Λ * cn.A)
            q = quad(cn.Λ, y1n - μ1) + quad(cn.Λ, y2n - μ2)
            # Trace term: 0.5 * tr(V) per coordinate, and both coordinates share S_k.
            -0.5 * q - tr(V) + digamma(alpha_g)
        end for (og, alpha_g) in zip(o, alpha)])
        for (cn, y1n, y2n) in zip(c, y1, y2)]
end
responsibilities (generic function with 1 method)

Next, let’s calculate the inducing point variational parameters. From the mean field assumption, the ELBO is maximized when

\(\log q(u_k) = \sum_i E_q 1_{z_i = k} \log p(y_i | u_k, z_i = k) + \log p(u_k) + \text{const}\) \(= \sum_i r_{ik} (-\frac{1}{2}\langle u_ku_k^T, A^T\Lambda A \rangle + \langle \Lambda A u_k, y_i \rangle) - \frac{1}{2}\langle u_ku_k^T, K_{uu}^{-1} \rangle + \text{const}\)

By combining like terms and matching against \(\log \mathcal{N}(u_k; m_k, S_k) = -\frac{1}{2}\langle u_ku_k^T, S_k^{-1} \rangle + \langle u_k, S_k^{-1}m_k \rangle + \text{const}\), we can see that \(S_k^{-1} = K_{uu}^{-1} + \sum_i r_{ik} A^T \Lambda A\) and \(S_k^{-1}m_k = \sum_i r_{ik} A^T \Lambda y_i\).

function fit_gp(g, K_UU_inv::PDMat{Float64,Matrix{Float64}}, r, y1, y2, c)
    if sum(rn -> rn[g], r) < 1e-8
        return nothing
    end
    S_inv = K_UU_inv + PDMat(sum(rn[g] * Xt_A_X(cn.Λ, cn.A) for (rn, cn) in zip(r, c) if rn[g] > 0))
    m1 = S_inv \ sum(rn[g] * cn.A' * (cn.Λ * yn) for (rn, cn, yn) in zip(r, c, y1))
    m2 = S_inv \ sum(rn[g] * cn.A' * (cn.Λ * yn) for (rn, cn, yn) in zip(r, c, y2))
    VComp(S_inv, m1, m2)
end
fit_gp (generic function with 1 method)
function fit_gps(K_UU_inv::PDMat{Float64,Matrix{Float64}}, r, y1, y2, c)::Vector{Union{Nothing,VComp}}
    Union{VComp,Nothing}[fit_gp(g, K_UU_inv, r, y1, y2, c) for g in 1:length(r[1])]
end
fit_gps (generic function with 1 method)
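A tiny worked case helps check the two update identities by hand. The sketch below uses made-up numbers throughout: three inducing points, \(A = I\) for every trajectory, \(\Lambda = 4I\), and a prior precision \(K_{uu}^{-1} = 2I\). None of these values come from the ship data.

```julia
using LinearAlgebra

# Toy setting (all values hypothetical).
U = 3
K_uu_inv = 2.0 * Matrix(I, U, U)            # prior precision κI with κ = 2
A = Matrix(1.0I, U, U)                      # inducing points sit at the observation times
Λ = Diagonal(fill(4.0, U))                  # observation precision λI with λ = 4
ys = [[1.0, 2.0, 3.0], [3.0, 2.0, 1.0], [0.0, 0.0, 0.0]]
r_k = [0.9, 0.1, 0.0]                       # responsibilities of component k

# S_k^{-1} = K_uu^{-1} + Σ_i r_ik AᵀΛA   and   S_k^{-1} m_k = Σ_i r_ik AᵀΛ y_i
S_inv = K_uu_inv + sum(r * (A' * Λ * A) for r in r_k)
m = S_inv \ sum(r * (A' * (Λ * y)) for (r, y) in zip(r_k, ys))
```

In this case \(S_k^{-1} = 6I\), and \(m_k\) is the responsibility-weighted average of the trajectories, \([1.2, 2.0, 2.8]\), shrunk toward the zero prior mean by the factor \(4/6\).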

Finally, we get to the mixture weight parameters. As before,

\(\log q(\pi) = \log p(\pi) + \sum_i E_q \log p(z_i | \pi) + \text{const}\) \(= \langle \alpha_0 - 1, \log \pi \rangle + \sum_i \langle r_i, \log \pi \rangle + \text{const}\)

This gives \(\alpha = \alpha_0 + \sum_i r_i\).
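In code, this update is just a sum of soft counts. A minimal sketch with made-up responsibilities for four trajectories and three routes:

```julia
# Hypothetical responsibilities: one vector per trajectory, one entry per route.
r = [[0.9, 0.1, 0.0], [0.8, 0.2, 0.0], [0.0, 0.0, 1.0], [0.1, 0.9, 0.0]]
alpha0 = fill(5.0, 3)

# α = α₀ + Σ_i r_i: each trajectory adds its soft assignment to the prior counts.
alpha = alpha0 + sum(r)

# Posterior mean of the mixture weights under Dirichlet(α).
mean_weights = alpha ./ sum(alpha)
```

Because each \(r_i\) sums to one, the total concentration grows by exactly the number of trajectories.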

Putting everything together gives the following variational inference algorithm.

This uses an auxiliary function to compute fixed points.

function fixedpoint(f, arg::T; iters=500) where {T}
    for _ in 1:iters
        result = f(arg)::T
        max_diff = maximum(abs, result - arg)
        if max_diff < 1e-5
            return result
        end
        arg = result
    end
    println("DID NOT CONVERGE")
    arg
end
fixedpoint (generic function with 1 method)
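As a quick sanity check of this style of iteration, the sketch below applies a compact copy of the helper to a map with a known fixed point. I renamed it `fixedpoint_sketch` and added a `tol` keyword to mark it as an illustration; the cosine map is only a stand-in for the variational updates.

```julia
# Compact copy of the fixed-point helper, renamed to mark it as a sketch.
function fixedpoint_sketch(f, x; iters=500, tol=1e-5)
    for _ in 1:iters
        x_new = f(x)
        # Stop once no entry changed by more than `tol`.
        maximum(abs, x_new - x) < tol && return x_new
        x = x_new
    end
    @warn "did not converge"
    x
end

# x ↦ cos(x) is a contraction near its fixed point, so plain iteration converges.
x_star = fixedpoint_sketch([1.0]) do x
    cos.(x)
end
```

The `do` block form matches how fit_mixture calls the helper. There, the iterate is \(\alpha\), and the component parameters are updated as a side effect of each step, so convergence is only checked on \(\alpha\).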
@sizecheck function fit_mixture(kern, alpha0, x_N, y1_N, y2_N, z_M, o; iters=50)
    K_MM_inv, c = kernel_mats(kern, x_N, z_M)
    alpha = fixedpoint(alpha0; iters) do alpha
        r = responsibilities(y1_N, y2_N, c, o, alpha)
        o = fit_gps(K_MM_inv, r, y1_N, y2_N, c)
        alpha0 + sum(r)
    end
    alpha, c, o
end
fit_mixture (generic function with 1 method)

Example on Synthetic Data

kern = with_lengthscale(Matern52Kernel(), 0.1)
Matern 5/2 Kernel (metric = Distances.Euclidean(0.0))
    - Scale Transform (s = 10.0)
T = 100
100
N = 100
100
π_prior = Dirichlet(5 * ones(3))
Distributions.Dirichlet{Float64, Vector{Float64}, Float64}(alpha=[5.0, 5.0, 5.0])
gp = GP(kern)
GP{ZeroMean{Float64}, TransformedKernel{Matern52Kernel{Distances.Euclidean}, ScaleTransform{Float64}}}(ZeroMean{Float64}(), Matern 5/2 Kernel (metric = Distances.Euclidean(0.0))
    - Scale Transform (s = 10.0))
rng = Xoshiro(9)
Xoshiro(0x31c0fdb77e6b079f, 0xf8bbbedb8c20d31c, 0x1c19355ea0d34d01, 0x402b368a357b9496, 0x69a0c2eabd4f1212)
pi = rand(π_prior)
3-element Vector{Float64}:
 0.36892846174002913
 0.38153355876999534
 0.24953797948997547
(true_f, y1_N, y2_N, x_N) = let
    x_T = LinRange(0, 1, T)
    true_f = [rand(rng, gp(x_T), 2) for _ in 1:3]
    noise = [σ .* randn(T, 2) for _ in 1:N]  # iid observation noise with scale σ
    z = rand(rng, Distributions.Categorical(pi), N)
    y_T2N = stack(true_f[z] .+ noise)
    y1_N = eachcol(y_T2N[:, 1, :])
    y2_N = eachcol(y_T2N[:, 2, :])
    x_N = [x_T for _ in 1:N]
    (true_f, y1_N, y2_N, x_N)
end
(Output truncated: a 4-tuple of the three sampled routes in true_f (each a 100×2 matrix), the 100 noisy trajectories for each coordinate in y1_N and y2_N, and the 100 time grids in x_N, each LinRange{Float64}(0.0, 1.0, 100).)
c_U = LinRange(0, 1, 25)
25-element LinRange{Float64, Int64}:
 0.0, 0.0416667, 0.0833333, 0.125, …, 0.875, 0.916667, 0.958333, 1.0

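To kick off variational inference, we’ll need starting values for \(S_k\) and \(m_k\). I’ll initialize every \(S_k\) to the kernel matrix \(K_{UU}\) over the points c_U (stored as its inverse), and fill the two coordinates of each \(m_k\) with independent random draws from the GP evaluated at c_U.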

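o_guess = let
    # The GP evaluated at the points c_U, with noise variance σ^2.
    f = gp(c_U, σ^2)
    # Inverse of the kernel matrix over c_U, shared by all three components.
    K_UU_inv = inv(PDMat(kern.(c_U', c_U)))
    # Each component gets two independent random draws, one per output coordinate.
    Union{Nothing,VComp}[VComp(K_UU_inv, rand(rng, f), rand(rng, f)) for _ in 1:3]
end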
3-element Vector{Union{Nothing, VComp}}:
 VComp([7.454870343138737 -12.085564351326695 … -4.395025867196253e-8 1.0695053577639308e-8; -12.085564351326695 27.047547206323763 … 1.8060921559003263e-7 -4.395025867240513e-8; … ; -4.395025867196253e-8 1.8060921559003263e-7 … 27.04754720632348 -12.085564351326603; 1.0695053577639308e-8 -4.395025867240513e-8 … -12.085564351326603 7.454870343138712], [-0.32979921047879956, -0.5897305605112932, 0.11447706895399364, 1.029017128451943, 1.0599601362252413, 0.5111666085710849, 0.07907917386581927, 0.2877821470816543, 0.6470354741278084, 1.4784262694762869  …  -0.5740153604150344, -0.5371827926948518, -0.7993928322576462, -0.5568548810814099, -1.1369227588787811, -1.798383746675905, -1.2631768279691984, -0.7743985404177847, -0.5390795325431714, 0.46438075009635527], [0.43883495716887116, 0.9848746247385155, 0.9515814521504622, 0.7643492693358623, 0.5271321278211968, 0.4894869429738167, 0.8325602734008992, 1.3088984166300248, 1.8516330253814923, 1.8097893768525444  …  -1.2254495200013993, -0.9601158375877539, -0.46840313165530256, -0.46897044443131936, -0.48417534075567253, -0.5578469644611623, -1.0491654256770087, -1.4257330683614824, -0.8617948360702556, 0.017421172919114353])
 VComp([7.454870343138737 -12.085564351326695 … -4.395025867196253e-8 1.0695053577639308e-8; -12.085564351326695 27.047547206323763 … 1.8060921559003263e-7 -4.395025867240513e-8; … ; -4.395025867196253e-8 1.8060921559003263e-7 … 27.04754720632348 -12.085564351326603; 1.0695053577639308e-8 -4.395025867240513e-8 … -12.085564351326603 7.454870343138712], [-0.7744159166454945, -1.2601807812167034, -0.9958548872815728, -0.17653467402440767, 0.9430594849897121, 2.0105500724949295, 2.033124578960417, 1.8565564339339036, 1.8221137639819547, 1.1363083966179663  …  0.15448540477257527, 0.41259110078244066, -0.0027348811068554422, 0.19620549642147278, 0.8763233073663659, 1.3773182347415502, 1.2869812975779953, 0.7746707958819704, -0.022812002763165848, -0.5813990084339119], [0.7846997214502921, 0.9282255890427066, 0.9645535603585261, 0.6532360444039843, 0.3532747775171625, -0.06511150394942444, 0.38443523877237284, 0.9039526689577165, 0.8965788941631468, 0.3359702471738463  …  -0.34629481463918804, -0.8349062772444377, -0.5182156620975424, -0.044966742684170016, 0.5312907147853932, 0.5535393563291215, 0.9294317565413944, 1.479226205158952, 1.5475181642383982, 0.8477215849890636])
 VComp([7.454870343138737 -12.085564351326695 … -4.395025867196253e-8 1.0695053577639308e-8; -12.085564351326695 27.047547206323763 … 1.8060921559003263e-7 -4.395025867240513e-8; … ; -4.395025867196253e-8 1.8060921559003263e-7 … 27.04754720632348 -12.085564351326603; 1.0695053577639308e-8 -4.395025867240513e-8 … -12.085564351326603 7.454870343138712], [1.7486689273216505, 2.4259674782795706, 2.23590574897882, 1.3197002380068752, 0.15255298167106107, -0.13640229167983647, -0.2978069479132432, -0.2357749372929084, -0.480667401701669, -0.3200016176917847  …  0.30610173054220047, 0.7773459629412118, 0.6983570030055796, 0.3619794114301518, 0.23999594112184505, -0.1174427074074989, -0.750398403767938, -1.5117006503169192, -1.4472670583060792, -1.0947027562950333], [0.6884887829721742, 1.0012706667133673, 1.4795170415409875, 1.1715523761660434, 1.059851802972681, 0.47428421889003575, 0.5799475934607595, 1.2958003220887966, 1.4223610253305643, 1.4728308154550735  …  -0.37139044454526404, 0.48014979684391895, 0.39591942033965993, -0.49328143269907426, -0.18136582469922452, 0.2431903591430569, 0.4586852603345645, -0.24386642954148235, -0.6263680839326332, -0.6205153515060421])
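
As a rough, standard-library-only sketch of the same initialization: the block below builds a kernel matrix over the 25 points, inverts it, and draws the starting means through a Cholesky factor. The squared-exponential kernel `se`, its length scale, the jitter, and the named-tuple layout are all assumptions made for illustration; the notebook's own `kern`, `gp`, and `VComp` are defined elsewhere and may differ. Unlike `rand(rng, f)` above, this version also leaves out the observation noise.

```julia
using LinearAlgebra, Random

# Assumed squared-exponential kernel (hypothetical stand-in for `kern`).
se(x, y; ell=0.05) = exp(-(x - y)^2 / (2 * ell^2))

rng = MersenneTwister(0)
c_U = LinRange(0, 1, 25)

# Kernel matrix over the points, with a small jitter for numerical stability.
K_UU = Symmetric(se.(c_U, c_U') + 1e-8I)
S_inv = inv(K_UU)

# If z ~ N(0, I) and K_UU = L * L', then L * z ~ N(0, K_UU).
L = cholesky(K_UU).L
draw() = L * randn(rng, length(c_U))

# Three components, each with a shared S_inv and two random mean vectors.
init = [(S_inv = S_inv, m1 = draw(), m2 = draw()) for _ in 1:3]
```

Starting every component from the same covariance but different random means is what lets the components drift apart during fitting; identical means would leave them indistinguishable.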
_, c, o = fit_mixture(kern, π_prior.alpha, x_N, y1_N, y2_N, c_U, o_guess; iters=50)
([33.0, 42.0, 40.0], KernelMats[KernelMats([1111.111111111111 0.0 … 0.0 0.0; 0.0 338.3030875290355 … 0.0 0.0; … ; 0.0 0.0 … 338.30308752883224 0.0; 0.0 0.0 … 0.0 1111.1111111071364], [0.9999999999999987 -5.1769817281390405e-15 … -2.0285485224130183e-19 5.1985138724645184e-20; 0.7242230000833083 0.38463921309029714 … 5.850640797546697e-10 -1.4237212403844214e-10; … ; -1.423721240887343e-10 5.850640799372934e-10 … 0.38463921309030774 0.7242230000833062; 6.253485430900889e-22 -1.036291871406433e-20 … 3.552713678800501e-15 0.9999999999999964]), KernelMats([1111.111111111111 0.0 … 0.0 0.0; 0.0 338.3030875290355 … 0.0 0.0; … ; 0.0 0.0 … 338.30308752883224 0.0; 0.0 0.0 … 0.0 1111.1111111071364], [0.9999999999999987 -5.1769817281390405e-15 … -2.0285485224130183e-19 5.1985138724645184e-20; 0.7242230000833083 0.38463921309029714 … 5.850640797546697e-10 -1.4237212403844214e-10; … ; -1.423721240887343e-10 5.850640799372934e-10 … 0.38463921309030774 0.7242230000833062; 6.253485430900889e-22 -1.036291871406433e-20 … 3.552713678800501e-15 0.9999999999999964]), KernelMats([1111.111111111111 0.0 … 0.0 0.0; 0.0 338.3030875290355 … 0.0 0.0; … ; 0.0 0.0 … 338.30308752883224 0.0; 0.0 0.0 … 0.0 1111.1111111071364], [0.9999999999999987 -5.1769817281390405e-15 … -2.0285485224130183e-19 5.1985138724645184e-20; 0.7242230000833083 0.38463921309029714 … 5.850640797546697e-10 -1.4237212403844214e-10; … ; -1.423721240887343e-10 5.850640799372934e-10 … 0.38463921309030774 0.7242230000833062; 6.253485430900889e-22 -1.036291871406433e-20 … 3.552713678800501e-15 0.9999999999999964]), KernelMats([1111.111111111111 0.0 … 0.0 0.0; 0.0 338.3030875290355 … 0.0 0.0; … ; 0.0 0.0 … 338.30308752883224 0.0; 0.0 0.0 … 0.0 1111.1111111071364], [0.9999999999999987 -5.1769817281390405e-15 … -2.0285485224130183e-19 5.1985138724645184e-20; 0.7242230000833083 0.38463921309029714 … 5.850640797546697e-10 -1.4237212403844214e-10; … ; -1.423721240887343e-10 5.850640799372934e-10 … 0.38463921309030774 
0.7242230000833062; 6.253485430900889e-22 -1.036291871406433e-20 … 3.552713678800501e-15 0.9999999999999964]), KernelMats([1111.111111111111 0.0 … 0.0 0.0; 0.0 338.3030875290355 … 0.0 0.0; … ; 0.0 0.0 … 338.30308752883224 0.0; 0.0 0.0 … 0.0 1111.1111111071364], [0.9999999999999987 -5.1769817281390405e-15 … -2.0285485224130183e-19 5.1985138724645184e-20; 0.7242230000833083 0.38463921309029714 … 5.850640797546697e-10 -1.4237212403844214e-10; … ; -1.423721240887343e-10 5.850640799372934e-10 … 0.38463921309030774 0.7242230000833062; 6.253485430900889e-22 -1.036291871406433e-20 … 3.552713678800501e-15 0.9999999999999964]), KernelMats([1111.111111111111 0.0 … 0.0 0.0; 0.0 338.3030875290355 … 0.0 0.0; … ; 0.0 0.0 … 338.30308752883224 0.0; 0.0 0.0 … 0.0 1111.1111111071364], [0.9999999999999987 -5.1769817281390405e-15 … -2.0285485224130183e-19 5.1985138724645184e-20; 0.7242230000833083 0.38463921309029714 … 5.850640797546697e-10 -1.4237212403844214e-10; … ; -1.423721240887343e-10 5.850640799372934e-10 … 0.38463921309030774 0.7242230000833062; 6.253485430900889e-22 -1.036291871406433e-20 … 3.552713678800501e-15 0.9999999999999964]), KernelMats([1111.111111111111 0.0 … 0.0 0.0; 0.0 338.3030875290355 … 0.0 0.0; … ; 0.0 0.0 … 338.30308752883224 0.0; 0.0 0.0 … 0.0 1111.1111111071364], [0.9999999999999987 -5.1769817281390405e-15 … -2.0285485224130183e-19 5.1985138724645184e-20; 0.7242230000833083 0.38463921309029714 … 5.850640797546697e-10 -1.4237212403844214e-10; … ; -1.423721240887343e-10 5.850640799372934e-10 … 0.38463921309030774 0.7242230000833062; 6.253485430900889e-22 -1.036291871406433e-20 … 3.552713678800501e-15 0.9999999999999964]), KernelMats([1111.111111111111 0.0 … 0.0 0.0; 0.0 338.3030875290355 … 0.0 0.0; … ; 0.0 0.0 … 338.30308752883224 0.0; 0.0 0.0 … 0.0 1111.1111111071364], [0.9999999999999987 -5.1769817281390405e-15 … -2.0285485224130183e-19 5.1985138724645184e-20; 0.7242230000833083 0.38463921309029714 … 5.850640797546697e-10 -1.4237212403844214e-10; … ; 
-1.423721240887343e-10 5.850640799372934e-10 … 0.38463921309030774 0.7242230000833062; 6.253485430900889e-22 -1.036291871406433e-20 … 3.552713678800501e-15 0.9999999999999964]), KernelMats([1111.111111111111 0.0 … 0.0 0.0; 0.0 338.3030875290355 … 0.0 0.0; … ; 0.0 0.0 … 338.30308752883224 0.0; 0.0 0.0 … 0.0 1111.1111111071364], [0.9999999999999987 -5.1769817281390405e-15 … -2.0285485224130183e-19 5.1985138724645184e-20; 0.7242230000833083 0.38463921309029714 … 5.850640797546697e-10 -1.4237212403844214e-10; … ; -1.423721240887343e-10 5.850640799372934e-10 … 0.38463921309030774 0.7242230000833062; 6.253485430900889e-22 -1.036291871406433e-20 … 3.552713678800501e-15 0.9999999999999964]), KernelMats([1111.111111111111 0.0 … 0.0 0.0; 0.0 338.3030875290355 … 0.0 0.0; … ; 0.0 0.0 … 338.30308752883224 0.0; 0.0 0.0 … 0.0 1111.1111111071364], [0.9999999999999987 -5.1769817281390405e-15 … -2.0285485224130183e-19 5.1985138724645184e-20; 0.7242230000833083 0.38463921309029714 … 5.850640797546697e-10 -1.4237212403844214e-10; … ; -1.423721240887343e-10 5.850640799372934e-10 … 0.38463921309030774 0.7242230000833062; 6.253485430900889e-22 -1.036291871406433e-20 … 3.552713678800501e-15 0.9999999999999964])  …  KernelMats([1111.111111111111 0.0 … 0.0 0.0; 0.0 338.3030875290355 … 0.0 0.0; … ; 0.0 0.0 … 338.30308752883224 0.0; 0.0 0.0 … 0.0 1111.1111111071364], [0.9999999999999987 -5.1769817281390405e-15 … -2.0285485224130183e-19 5.1985138724645184e-20; 0.7242230000833083 0.38463921309029714 … 5.850640797546697e-10 -1.4237212403844214e-10; … ; -1.423721240887343e-10 5.850640799372934e-10 … 0.38463921309030774 0.7242230000833062; 6.253485430900889e-22 -1.036291871406433e-20 … 3.552713678800501e-15 0.9999999999999964]), KernelMats([1111.111111111111 0.0 … 0.0 0.0; 0.0 338.3030875290355 … 0.0 0.0; … ; 0.0 0.0 … 338.30308752883224 0.0; 0.0 0.0 … 0.0 1111.1111111071364], [0.9999999999999987 -5.1769817281390405e-15 … -2.0285485224130183e-19 5.1985138724645184e-20; 0.7242230000833083 
0.38463921309029714 … 5.850640797546697e-10 -1.4237212403844214e-10; … ; -1.423721240887343e-10 5.850640799372934e-10 … 0.38463921309030774 0.7242230000833062; 6.253485430900889e-22 -1.036291871406433e-20 … 3.552713678800501e-15 0.9999999999999964]), KernelMats([1111.111111111111 0.0 … 0.0 0.0; 0.0 338.3030875290355 … 0.0 0.0; … ; 0.0 0.0 … 338.30308752883224 0.0; 0.0 0.0 … 0.0 1111.1111111071364], [0.9999999999999987 -5.1769817281390405e-15 … -2.0285485224130183e-19 5.1985138724645184e-20; 0.7242230000833083 0.38463921309029714 … 5.850640797546697e-10 -1.4237212403844214e-10; … ; -1.423721240887343e-10 5.850640799372934e-10 … 0.38463921309030774 0.7242230000833062; 6.253485430900889e-22 -1.036291871406433e-20 … 3.552713678800501e-15 0.9999999999999964]), KernelMats([1111.111111111111 0.0 … 0.0 0.0; 0.0 338.3030875290355 … 0.0 0.0; … ; 0.0 0.0 … 338.30308752883224 0.0; 0.0 0.0 … 0.0 1111.1111111071364], [0.9999999999999987 -5.1769817281390405e-15 … -2.0285485224130183e-19 5.1985138724645184e-20; 0.7242230000833083 0.38463921309029714 … 5.850640797546697e-10 -1.4237212403844214e-10; … ; -1.423721240887343e-10 5.850640799372934e-10 … 0.38463921309030774 0.7242230000833062; 6.253485430900889e-22 -1.036291871406433e-20 … 3.552713678800501e-15 0.9999999999999964]), KernelMats([1111.111111111111 0.0 … 0.0 0.0; 0.0 338.3030875290355 … 0.0 0.0; … ; 0.0 0.0 … 338.30308752883224 0.0; 0.0 0.0 … 0.0 1111.1111111071364], [0.9999999999999987 -5.1769817281390405e-15 … -2.0285485224130183e-19 5.1985138724645184e-20; 0.7242230000833083 0.38463921309029714 … 5.850640797546697e-10 -1.4237212403844214e-10; … ; -1.423721240887343e-10 5.850640799372934e-10 … 0.38463921309030774 0.7242230000833062; 6.253485430900889e-22 -1.036291871406433e-20 … 3.552713678800501e-15 0.9999999999999964]), KernelMats([1111.111111111111 0.0 … 0.0 0.0; 0.0 338.3030875290355 … 0.0 0.0; … ; 0.0 0.0 … 338.30308752883224 0.0; 0.0 0.0 … 0.0 1111.1111111071364], [0.9999999999999987 -5.1769817281390405e-15 … 
-2.0285485224130183e-19 5.1985138724645184e-20; 0.7242230000833083 0.38463921309029714 … 5.850640797546697e-10 -1.4237212403844214e-10; … ; -1.423721240887343e-10 5.850640799372934e-10 … 0.38463921309030774 0.7242230000833062; 6.253485430900889e-22 -1.036291871406433e-20 … 3.552713678800501e-15 0.9999999999999964]), KernelMats([1111.111111111111 0.0 … 0.0 0.0; 0.0 338.3030875290355 … 0.0 0.0; … ; 0.0 0.0 … 338.30308752883224 0.0; 0.0 0.0 … 0.0 1111.1111111071364], [0.9999999999999987 -5.1769817281390405e-15 … -2.0285485224130183e-19 5.1985138724645184e-20; 0.7242230000833083 0.38463921309029714 … 5.850640797546697e-10 -1.4237212403844214e-10; … ; -1.423721240887343e-10 5.850640799372934e-10 … 0.38463921309030774 0.7242230000833062; 6.253485430900889e-22 -1.036291871406433e-20 … 3.552713678800501e-15 0.9999999999999964]), KernelMats([1111.111111111111 0.0 … 0.0 0.0; 0.0 338.3030875290355 … 0.0 0.0; … ; 0.0 0.0 … 338.30308752883224 0.0; 0.0 0.0 … 0.0 1111.1111111071364], [0.9999999999999987 -5.1769817281390405e-15 … -2.0285485224130183e-19 5.1985138724645184e-20; 0.7242230000833083 0.38463921309029714 … 5.850640797546697e-10 -1.4237212403844214e-10; … ; -1.423721240887343e-10 5.850640799372934e-10 … 0.38463921309030774 0.7242230000833062; 6.253485430900889e-22 -1.036291871406433e-20 … 3.552713678800501e-15 0.9999999999999964]), KernelMats([1111.111111111111 0.0 … 0.0 0.0; 0.0 338.3030875290355 … 0.0 0.0; … ; 0.0 0.0 … 338.30308752883224 0.0; 0.0 0.0 … 0.0 1111.1111111071364], [0.9999999999999987 -5.1769817281390405e-15 … -2.0285485224130183e-19 5.1985138724645184e-20; 0.7242230000833083 0.38463921309029714 … 5.850640797546697e-10 -1.4237212403844214e-10; … ; -1.423721240887343e-10 5.850640799372934e-10 … 0.38463921309030774 0.7242230000833062; 6.253485430900889e-22 -1.036291871406433e-20 … 3.552713678800501e-15 0.9999999999999964]), KernelMats([1111.111111111111 0.0 … 0.0 0.0; 0.0 338.3030875290355 … 0.0 0.0; … ; 0.0 0.0 … 338.30308752883224 0.0; 0.0 0.0 … 0.0 
1111.1111111071364], [0.9999999999999987 -5.1769817281390405e-15 … -2.0285485224130183e-19 5.1985138724645184e-20; 0.7242230000833083 0.38463921309029714 … 5.850640797546697e-10 -1.4237212403844214e-10; … ; -1.423721240887343e-10 5.850640799372934e-10 … 0.38463921309030774 0.7242230000833062; 6.253485430900889e-22 -1.036291871406433e-20 … 3.552713678800501e-15 0.9999999999999964])], Union{Nothing, VComp}[VComp([38046.26445509075 5325.464176322597 … 8.006537288460481e-5 -2.1740008405797687e-5; 5325.464176322597 62727.21482930112 … -0.00029091409508395566 8.006537289784688e-5; … ; 8.006537288460481e-5 -0.00029091409508395566 … 62727.214828988595 5325.464176318087; -2.1740008405797687e-5 8.006537289784688e-5 … 5325.464176318087 38046.264454973796], [1.7884384433739497, 2.0078691637258808, 1.3336730045136205, 0.27314880593905083, 0.11717217357332622, 0.7608730047237832, 1.418507139298858, 1.5047548465666285, 1.2232513340114586, 0.5870425288598827  …  -0.5090363917647711, 0.4159266785825978, 1.142011701120123, 0.3991566630081422, -0.03828048002571067, 0.45909901400871095, 0.30936063082589194, 0.003078292861633498, 0.07752886898789478, 0.4972970076652429], [0.251311327227621, 0.4888737351135423, 0.8867746410121203, 1.086589015474623, 1.359889255567608, 1.1663408298950968, 0.6742235065052071, 0.43683503392133255, 0.50347745629217, 0.28420664577352606  …  -1.3913449608047943, -1.4820787109705316, -1.0020681941677192, 0.015795349003787432, 0.354779181378556, -0.12879054552524927, -1.0416735978677272, -1.3667321599783797, -0.24280530297126435, 0.585944704826394]), VComp([50273.02467875962 7041.105164396358 … 0.00010581479818065806 -2.8731305946311186e-5; 7041.105164396358 82880.84002711735 … -0.0003844802500373811 0.0001058147981981566; … ; 0.00010581479818065806 -0.0003844802500373811 … 82880.84002670436 7041.105164390399; -2.8731305946311186e-5 0.0001058147981981566 … 7041.105164390399 50273.02467860508], [-0.4836135038178731, -0.988678978737391, -1.3229272605434894, 
-0.32482102645182676, 0.7706544694426579, 0.8325822141583343, 0.5987229795637262, -0.5241990119238615, -1.9107531261540993, -2.4820153872991875  …  -0.6464292997148956, -0.5574036888828474, -0.3351650246270802, 0.1112275857600006, 0.32323019350882154, 0.8085350087723989, 1.5858869985684616, 1.6837522737163038, 1.20160003294185, 0.8592494110733625], [-0.31784821189198365, 0.1986421274795545, 0.07395072617986106, -0.16580134440202543, -0.22344252696654085, 0.1827130996591008, 1.0423472580033715, 1.6508726669189218, 1.8051465608388193, 1.3977140239824084  …  0.12167854454672755, -0.11185061677671573, -0.10965042590176122, 0.19013484987004253, 0.8374248207113852, 1.2895245916656173, 1.2016737602722332, 0.9172989551111528, 0.31352000354086257, -0.07552767910043261]), VComp([47555.96685127765 6659.851611491078 … 0.000100092703670424 -2.717768427064152e-5; 6659.851611491078 78402.25664982485 … -0.0003636877711588421 0.00010009270368697667; … ; 0.000100092703670424 -0.0003636877711588421 … 78402.25664943419 6659.851611485441; -2.717768427064152e-5 0.00010009270368697667 … 6659.851611485441 47555.96685113146], [0.4066548669398225, 0.6527573699038731, 0.5816931561575486, 0.39017907478916863, 0.9103386003926413, 0.8994363236523513, 0.40972139186328504, -0.26309126872025157, -1.0235610420699035, -1.2045025137506382  …  2.3605910456886523, 1.6760297054286803, 0.8655616471928317, -0.3470911540061117, -0.7543098221945449, 0.12175262357027786, 1.4032123448577127, 1.7844745477706296, 1.5926862118501055, 1.5854845568338483], [0.776847463305056, 1.38569216887199, 1.6662121732024364, 1.641756303077532, 1.5222001502924467, 0.9846669635716546, 0.8003230315970344, 0.39959120493513467, 0.302546033373993, 0.6610598848670328  …  -1.9937429281585854, -2.5432907386447496, -2.606678025607028, -1.921237570472245, -1.0458519688106935, -0.4689228809106339, 0.07851774467505346, 0.4218121455453776, 0.6476437583638489, 0.3185399115368512])])

We can visualize how well inference recovered the mixture components by plotting a circle at each GP prediction along the trajectory. The radius of each circle is the posterior standard deviation at that point. I’ll also draw the true, synthetically generated functions for comparison.

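function plot_example(o, kern, c_U, true_f)
    f = Figure()
    ax = Axis(f[1, 1])
    # Evaluate every component on a common grid of 100 normalized times.
    x_N = [LinRange(0, 1, 100)]
    _, c = kernel_mats(kern, x_N, c_U)
    # Skip entries of `o` that are `nothing`.
    for on in filter(!isnothing, o)
        # Predicted mean of both output coordinates at each time.
        μ = Point2d.(c[1].A * on.m1, c[1].A * on.m2)
        # Predictive standard deviation: diag(Λ⁻¹) plus aᵢᵀ S aᵢ for each row aᵢ of A.
        σ = sqrt.(diag(inv(c[1].Λ)) .+ invquad(on.S_inv, c[1].A'))
        poly!(ax, Circle.(μ, σ), alpha=0.2)
    end
    # Overlay the three true functions for comparison.
    for i in 1:3
        lines!(ax, true_f[i][:, 1], true_f[i][:, 2])
    end
    f
end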
plot_example (generic function with 1 method)
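
The radius computed in `plot_example` combines two variances: the diagonal of \(\Lambda^{-1}\) and the quadratic form \(a_i^\top S a_i\) for each row \(a_i\) of \(A\). Here is a minimal sketch of that calculation using only `LinearAlgebra`. It rests on my reading of the printed `KernelMats` values, namely that \(A = K_{NU} K_{UU}^{-1}\) and that \(\Lambda^{-1}\) holds the leftover conditional variance plus noise. The kernel `se`, the grids, and `noise_var` are hypothetical stand-ins rather than the notebook's values.

```julia
using LinearAlgebra

# Assumed squared-exponential kernel (hypothetical stand-in for `kern`).
se(x, y; ell=0.1) = exp(-(x - y)^2 / (2 * ell^2))

x = LinRange(0, 1, 9)     # evaluation times (hypothetical)
u = LinRange(0, 1, 5)     # inducing points (hypothetical)
noise_var = 0.03^2        # assumed observation noise variance

K_UU = se.(u, u') + 1e-8I
K_NU = se.(x, u')
A = K_NU / K_UU           # N×U weights, i.e. K_NU * inv(K_UU)

# Diagonal of K_NN - K_NU * inv(K_UU) * K_UN, plus noise.
k_diag = [se(xi, xi) for xi in x]
cond_var = k_diag .- vec(sum(A .* K_NU, dims=2)) .+ noise_var

# Placeholder variational covariance: the prior itself.
S = K_UU

# a_i' * S * a_i for every row a_i of A, without forming the full N×N product.
quad = vec(sum((A * S) .* A, dims=2))

pred_std = sqrt.(cond_var .+ quad)
```

With `S` set to the prior covariance, the two terms add back up to the prior variance plus noise, so every entry of `pred_std` equals `sqrt(1 + noise_var)` up to rounding. That makes a handy sanity check: after fitting, `S` shrinks wherever data constrain a component, and the circles shrink with it.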
plot_example(o, kern, c_U, true_f)