Seeing Bayesian Optimization in Action: An Animated MATLAB Example


January 14, 2024

Bayesian optimization is an iterative process. It starts with prior beliefs about the unknown objective function, such as how smooth it is, and updates those beliefs after every new evaluation. At each iteration an acquisition function uses the current beliefs to decide where to evaluate the objective next, so the evidence collected over time progressively refines the model of the objective.

To implement Bayesian optimization, the first step is to choose a surrogate model that represents the unknown objective function of the system being modeled and optimized. A Gaussian process (GP) is the most common choice, and it is the model used in this demo.

Gaussian Process

A Gaussian process is specified by a mean function $\mu(x)$ and a covariance function $k(x,x')$. The objective function is then modeled as:

$f(x) \sim \mathcal{GP}(\mu(x), k(x,x'))$

The prior mean is usually assumed to be zero. Each new observation updates both the posterior mean and the posterior covariance through the conditioning equations given under Model Update.
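To make "prior beliefs" concrete, the sketch below draws a few random functions from a zero-mean GP prior that uses the squared exponential kernel described next, with the same l = 1 and sig_var = 1 as the demo. It is a minimal base-MATLAB sketch (no toolboxes); the evaluation grid, the number of draws, and the 1e-8 diagonal jitter are illustrative assumptions, not values taken from the demo.

```matlab
% Draw sample functions from a zero-mean GP prior (illustrative sketch).
l       = 1;                        % length scale (same as the demo)
sig_var = 1;                        % signal variance sigma^2 (same as the demo)
x       = linspace(-5, 5, 200)';    % evaluation grid (assumed)
n_samp  = 3;                        % number of prior draws (assumed)

% Squared exponential kernel via implicit expansion (R2016b+):
% K(i,j) = sig_var * exp(-(x_i - x_j)^2 / (2 l^2))
K = sig_var * exp(-(x - x').^2 / (2*l^2));

% Assumed jitter of 1e-8 on the diagonal keeps the Cholesky factorization stable
L = chol(K + 1e-8*eye(numel(x)), 'lower');

% If u ~ N(0, I), then f = L*u ~ N(0, K): each column of F is one prior draw
rng(0);                             % fixed seed for reproducibility
F = L * randn(numel(x), n_samp);
```

Calling plot(x, F) afterwards shows three smooth curves centered on zero; rerunning with a smaller l produces visibly wigglier draws, which is how the length scale encodes the smoothness assumption.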

The covariance function $k(x,x')$ is also called the kernel. A common kernel for Gaussian processes is the squared exponential (also called exponentiated quadratic) kernel:

$K(X_1,X_2)=\sigma^2 \exp\left(-\frac{\|X_1 - X_2\|^2}{2l^2}\right)$
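Here $\sigma^2$ is the signal variance, which sets the overall scale of the function values, and $l$ is the length scale, which sets how quickly the correlation between two points decays with their distance. $X_1$ and $X_2$ are input points, which may be scalars or vectors (for example, points in a 2-D search space). The kernel can be implemented in MATLAB as given below: kernel_x builds the covariance matrix among the points of a single set, and kernel_xy builds the cross-covariance between two sets. In both functions, sig_var is the variance $\sigma^2$ itself, not $\sigma$.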

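% Covariance matrix among the rows of X (n-by-d): K(X,X)
% Requires pdist/squareform (Statistics and Machine Learning Toolbox)
function f = kernel_x(X, l, sig_var)
    distance = squareform(pdist(X));            % n-by-n Euclidean distances
    f = sig_var*exp(-distance.^2./(2*l^2));     % squared exponential kernel
end
% Cross-covariance between the rows of x (n-by-d) and y (m-by-d): K(x,y)
% Requires pdist2 (Statistics and Machine Learning Toolbox)
function f = kernel_xy(x, y, l, sig_var)
    distance = pdist2(x, y);                    % n-by-m Euclidean distances
    f = sig_var*exp(-distance.^2./(2*l^2));     % sig_var is the variance sigma^2
end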
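kernel_x and kernel_xy rely on pdist, squareform, and pdist2 from the Statistics and Machine Learning Toolbox. If that toolbox is unavailable, the same kernel can be sketched in base MATLAB (R2016b or later, for implicit expansion) by expanding $\|a-b\|^2 = \|a\|^2 + \|b\|^2 - 2a^\top b$. The anonymous function name se_kernel is hypothetical; it follows the demo's convention that the variance $\sigma^2$ is passed directly as sig_var.

```matlab
% Hypothetical toolbox-free squared exponential kernel.
% X1: n-by-d, X2: m-by-d; returns the n-by-m matrix
% K(i,j) = sig_var * exp(-||X1(i,:) - X2(j,:)||^2 / (2*l^2)).
% max(...,0) clamps tiny negative squared distances caused by rounding.
se_kernel = @(X1, X2, l, sig_var) sig_var * exp( ...
    -max(sum(X1.^2, 2) + sum(X2.^2, 2)' - 2*(X1*X2'), 0) / (2*l^2));

% Worked values for l = 1, sig_var = 1 and inputs 0, 1, 2:
K = se_kernel([0; 1; 2], [0; 1; 2], 1, 1);
% K(1,1) = exp(0)    = 1        a point is fully correlated with itself
% K(1,2) = exp(-1/2) ~ 0.6065   points one length scale apart
% K(1,3) = exp(-2)   ~ 0.1353   correlation decays quickly with distance
disp(K)
```

The same handle also accepts multi-column inputs, e.g. se_kernel(rand(4,2), rand(3,2), 1, 1) returns a 4-by-3 cross-covariance matrix.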

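The next step is the acquisition function, which uses the GP's predicted mean and uncertainty to score every candidate point and select the next point to evaluate. Several acquisition functions exist, such as probability of improvement, expected improvement, and confidence bounds; this demo uses expected improvement (EI).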

Acquisition Function

\(\text{EI}(x;\xi) = \begin{cases} \left(f(x^\star) - \mu(x) - \xi\right) \Phi\left(\frac{f(x^\star) - \mu(x)-\xi}{\sigma(x)}\right) + \sigma(x) \phi\left(\frac{f(x^\star) - \mu(x) -\xi}{\sigma(x)}\right) , & \quad \sigma(x) > 0 \\ 0, & \quad \sigma(x) = 0 \end{cases}\) where
EI $\rightarrow$ expected improvement (written here for minimization)
$\xi$ $\rightarrow$ exploration-exploitation trade-off parameter (larger values favor exploration)
$\mu(x)$ $\rightarrow$ posterior mean of the GP at $x$
$\sigma(x)$ $\rightarrow$ posterior standard deviation (predictive uncertainty) of the GP at $x$
$\phi$ $\rightarrow$ standard normal probability density function
$\Phi$ $\rightarrow$ standard normal cumulative distribution function
$f(x^\star)$ $\rightarrow$ best (lowest) objective value observed so far

Since $d = z\,\sigma(x)$, the above equation can be written in the compact form \(\text{EI}(x;\xi) = \begin{cases} \sigma(x) \left( z \Phi\left( z \right) + \phi\left(z\right) \right) , & \quad \sigma(x) > 0 \\ 0, & \quad \sigma(x) = 0 \end{cases}\)

where
$z = \frac{d}{\sigma(x)}$
$d=f(x^\star) - \mu(x)-\xi $
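As a worked example, consider a candidate whose posterior mean equals the best observed value, $\mu(x) = f(x^\star)$, with $\xi = 0$ and $\sigma(x) = 1$:

\[ d = 0,\qquad z = 0,\qquad \text{EI} = 1\cdot\left(0\cdot\Phi(0) + \phi(0)\right) = \frac{1}{\sqrt{2\pi}} \approx 0.399 \]

Even though this point is not predicted to improve on $f(x^\star)$ on average, its EI is positive purely because of its uncertainty, which is how EI rewards exploration.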

MATLAB implementation:

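function [EI,phi,PHI,z,idx] = aqufun(ytrain, mu, stdv, iter)
    % Expected improvement (EI) for minimization.
    % Requires normpdf/normcdf (Statistics and Machine Learning Toolbox).
    % Exploration-exploitation parameter (Greek letter xi):
    %   high xi = more exploration
    %   low xi  = more exploitation (can be < 0)
    if iter == 1
        xi = -0.2;                 % first iteration only
    else
        xi = 0.02;
    end
    [f_star, idx] = min(ytrain);   % best (lowest) observation so far and its index
    d = f_star - mu - xi;          % improvement margin (f* - mu - xi) for minimization
    z = d./stdv;
    phi = normpdf(z);              % standard normal pdf
    PHI = normcdf(z);              % standard normal cdf

    EI = d.*PHI + stdv.*phi;
    % sigma(x) = 0 branch; multiplying by a 0/1 mask would keep NaNs (0*NaN = NaN)
    EI(stdv <= 0) = 0;
end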
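aqufun depends on normpdf and normcdf from the Statistics and Machine Learning Toolbox. The hedged, base-MATLAB sketch below checks the two EI forms against each other: it writes the standard normal density and CDF with exp and erfc, evaluates the expanded and compact expressions on a few hypothetical values of $\mu(x)$ and $\sigma(x)$, and handles the $\sigma(x) = 0$ branch explicitly. The names std_pdf and std_cdf and all numeric inputs are assumptions for illustration.

```matlab
% Hypothetical toolbox-free standard normal pdf and cdf
std_pdf = @(z) exp(-0.5*z.^2) / sqrt(2*pi);
std_cdf = @(z) 0.5 * erfc(-z / sqrt(2));

f_star = 0;                          % best (lowest) observation so far (assumed)
xi     = 0.02;                       % value the demo uses after the first iteration
mu     = [-0.5;  0.0;  0.0;  0.3];   % hypothetical posterior means
stdv   = [ 0.1;  0.1;  1.0;  0.0];   % hypothetical posterior std devs (last is zero)

d = f_star - mu - xi;                % improvement margin for minimization
z = d ./ stdv;                       % -Inf in the last row, handled below

EI_expanded = d .* std_cdf(z) + stdv .* std_pdf(z);
EI_compact  = stdv .* (z .* std_cdf(z) + std_pdf(z));   % last row is 0*NaN = NaN

% Enforce the sigma(x) = 0 branch explicitly
EI_expanded(stdv == 0) = 0;
EI_compact(stdv == 0)  = 0;

disp([mu stdv EI_expanded EI_compact])
```

Rows 2 and 3 share the same mean, yet the row with the larger $\sigma(x)$ has the larger EI (exploration); row 1 has a mean well below $f(x^\star)$ and also scores highly (exploitation). The compact form is what produces the NaN when $\sigma(x) = 0$, which is why the zero branch is set by indexing rather than by masking.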

Model Update

Assuming the observations are corrupted by Gaussian noise of variance $\sigma_n^2$, the joint prior distribution of the training targets, $y$, and the test outputs, $f_\star$, is:

\[\begin{bmatrix} y \\ f_\star\end{bmatrix} \sim \mathcal{N} \left(0,\begin{bmatrix} K(X,X) + \sigma_n^2I & K(X,X_\star) \\ K(X_\star,X) & K(X_\star,X_\star) \end{bmatrix}\right)\]

Conditioning this joint Gaussian on the observed targets $y$ gives the posterior distribution over the function values at the test inputs, i.e. over functions that agree with the observed data:

\[f_\star|X,y,X_\star \sim \mathcal{N}(\bar{f_\star}, cov(f_\star)),\]

where

\[\begin{aligned} \bar{f_\star} &\overset{\Delta}{=} \mathbb{E}[f_\star|X,y,X_\star] = K(X_\star,X)[K(X,X) + \sigma_n^2I]^{-1}y, \\ cov(f_\star) &= K(X_\star,X_\star) - K(X_\star,X)[K(X,X)+\sigma_n^2I]^{-1}K(X,X_\star). \end{aligned}\]
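These equations can be checked numerically on a tiny example. The sketch below assumes the demo's five starting sine observations, a noise variance of 1e-6, and base MATLAB only. It evaluates the posterior once with the linear solves written as in the formulas (using backslash instead of an explicit inverse) and once through a Cholesky factor, as the MATLAB implementation of GP does; both routes should agree to rounding error.

```matlab
% Tiny 1-D example (all values are assumptions for illustration)
X  = [-4; -3; -2; -1; 1];           % training inputs (same as the demo)
y  = sin(X);                        % training targets
Xs = linspace(-5, 5, 11)';          % test inputs
l = 1; sig_var = 1; noise_var = 1e-6;

k = @(A, B) sig_var * exp(-(A - B').^2 / (2*l^2));   % 1-D squared exponential kernel

Ky  = k(X, X) + noise_var*eye(numel(X));   % K(X,X) + sigma_n^2 I
Ks  = k(Xs, X);                            % K(X*,X)
Kss = k(Xs, Xs);                           % K(X*,X*)

% Route 1: formulas as written, with A\b in place of inv(A)*b
mu_direct  = Ks * (Ky \ y);
cov_direct = Kss - Ks * (Ky \ Ks');

% Route 2: Cholesky factor Ky = L*L'
L        = chol(Ky, 'lower');
alpha    = L' \ (L \ y);
v        = L \ Ks';
mu_chol  = Ks * alpha;
cov_chol = Kss - v' * v;

disp(max(abs(mu_direct - mu_chol)))
disp(max(abs(cov_direct(:) - cov_chol(:))))
```

The Cholesky route is preferred in practice because it factorizes the matrix once and reuses the factor for both the mean and the covariance, and it is numerically more stable than forming an inverse.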

MATLAB implementation:

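function [mu, covar, stdv] = GP(xtrain, ytrain, xtest, l, sig_var, noise_var)
    % Posterior mean, covariance and standard deviation of the GP at xtest.
    % xtrain: n-by-d, ytrain: n-by-1, xtest: m-by-d; uses kernel_xy from above.
    n_t  = size(xtrain, 1);                        % number of training points
    K_ss = kernel_xy(xtest, xtest, l, sig_var);    % K(X*,X*)
    K    = kernel_xy(xtrain, xtrain, l, sig_var);  % K(X,X)
    % Cholesky factor of K(X,X) + sigma_n^2 I; the small jitter (1e-8) guards
    % against failure when noise_var = 0 and training points are close together
    L    = chol(K + (noise_var + 1e-8)*eye(n_t), 'lower');
    K_s  = kernel_xy(xtrain, xtest, l, sig_var);   % K(X,X*)
    alpha = L'\(L\ytrain);                         % [K(X,X) + sigma_n^2 I]^{-1} y
    mu    = K_s'*alpha;                            % posterior mean
    v     = L\K_s;
    covar = K_ss - v'*v;                           % posterior covariance
    stdv  = sqrt(max(diag(covar), 0));             % clip tiny negative variances
end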

Full MATLAB Code

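For this demonstration, observations are sampled from a sine function, so the true objective is $\sin(x)$ on $[-5, 5]$. As observations accumulate, the GP posterior mean $\mu(x)$ is expected to resemble the sine curve more and more closely, especially in the regions where the search concentrates its samples.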

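% Requires kernel_xy, aqufun and GP (defined above) as local functions at the
% end of this script (R2016b+) or as separate .m files, plus normpdf, normcdf
% and pdist2 from the Statistics and Machine Learning Toolbox.
close all
clc
l = 1;                          % kernel length scale
sig_var = 1;                    % kernel signal variance (sigma^2)
xtrain = [-4,-3, -2, -1, 1]';   % initial observation locations
noise = 0.2;                    % not used below: GP is called with zero noise
ytrain = sin(xtrain);           % noise-free observations of the true function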

% "Test" data (what we want to estimate/interpolate)
n_s = 500;
xtest = transpose(linspace(-5,5,n_s));
% Create plot 
figure(1)
set(gcf, 'WindowState', 'maximized');
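for i=1:10
    % 1) Fit the GP posterior to the current observations (noise-free)
    [mu,cov,stdv]=GP(xtrain,ytrain,xtest,l,sig_var,0);
    stdv = real(stdv);
    % 2) Score every test point with expected improvement
    [EI, phi, PHI, z, idx] = aqufun(ytrain, mu, stdv,i);
    % 3) Evaluate the true function at the EI maximizer and add it to the data
    [eimax,posEI] = max(EI);
    xEI = xtest(posEI,:);
    xtrain(end+1,:) = xEI;
    ytrain(end+1) = sin(xEI);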

    f_above = mu+2*stdv;
    f_below = mu-2*stdv;
    xx = xtest;
    if(i==1)
        test_pos =205;
    else
        test_pos = posEI;%205;
    end
    ypdf = normpdf(xtest);
    ycdf = normcdf(xtest);
    ff = [f_below; flip(f_above,1)];
    subplot(211)
    fill([xx;flip(xx,1)],ff,[7 7 7]/8)
    ylim([-4 4]);
    xlim([-5 5]);
    hold on
    plot(xtest,sin(xtest),'-b','LineWidth',2)
    scatter(xtrain,ytrain,'blue','filled','diamond')
    plot(xtest,mu,'r--','LineWidth',2)
   
    plot(xtest,z+mu(test_pos,1),'--g')
    if (i==1)
        plot(ypdf+xtest(test_pos,1),xtest+mu(test_pos,1));
        plot(ycdf+xtest(test_pos,1),xtest+mu(test_pos,1));

        %
        fill([xtest(test_pos,1);xtest(test_pos,1);flip(ycdf(find(xtest<=z(test_pos,1)),:),1)+xtest(test_pos,1);xtest(test_pos,1)],...
            [xtest(1,1)+mu(test_pos,1);z(test_pos,1)+mu(test_pos,1);flip(xtest(find(xtest<=z(test_pos,1)),:),1)+mu(test_pos,1);...
            xtest(1,1)+mu(test_pos,1)],[0.8500 0.3250 0.0980],'FaceAlpha',0.3)
    else
        if(test_pos < 250)
            quiver(xtest(test_pos+10,1), -2, -0.15, -1.9, 0, 'r','MaxHeadSize', 0.8,'LineWidth', 4);
        else
            quiver(xtest(test_pos-10,1), -2, 0.15, -1.9, 0,'r','MaxHeadSize', 0.8, 'LineWidth', 4);
        end
    end
    px = [-5 5];
    py = [ytrain(idx) ytrain(idx)];
    px1 = [xtrain(idx) xtrain(idx)];
    px2 = [xtest(test_pos,1) xtest(test_pos,1)];
    py1 = [-5 ytrain(idx)];
    py2 = [-5 5];
    line(px,py,'LineStyle','--','LineWidth',2);
    line(px1,py1,'LineStyle','--','LineWidth',2);
    if (i==1)
        line(px2,py2,'Color','#D95319','LineStyle','-');
    end
    title("Data fitting process with Bayesian optimization")
    xlabel("x")
    ylabel("f(x)")
    legend("$\pm 2\sigma(x)$ region","true function","observed points","$\mu$ (x)", ...
        "$z = (f(x^\star) - \mu (x)-\xi)/\sigma$",'Location', 'northeast','Interpreter','latex');

    if(i==1)

        % Create text
        text('FontWeight','bold','String','z(x_{test}) \rightarrow',...
            'Position',[-1.516111793479424,0.842299072270407,0]);
        % Create textarrow
        annotation(gcf,'textarrow',[0.498828125 0.518750000000004],...
            [0.828030303030303 0.795454545454545],'String',{'$\Phi(z)$'},...
            'HorizontalAlignment','center','Interpreter','latex');

        % Create textarrow
        annotation(gcf,'textarrow',[0.429296875 0.4515625],...
            [0.834090909090909 0.796969696969697],'String','$\phi(z)$',...
            'HorizontalAlignment','center','Interpreter','latex');

        % Create textbox
        annotation(gcf,'textbox',...
            [0.343080357142864,0.545937785766092,0.011997767857138,0.027858961826228],...
            'String','x^*',...
            'FontWeight','bold',...
            'FontSize',12,...
            'EdgeColor','none');


        % Create textbox
        annotation(gcf,'textbox',...
            [0.418229166666668,0.547169930312498,0.013281249999999,0.023999999999999],...
            'String','x_{test}',...
            'FontWeight','bold',...
            'FitBoxToText','off',...
            'EdgeColor','none');

        % Create textbox
        annotation(gcf,'textbox',...
            [0.088715841450217,0.694070297958506,0.023268399016965,0.035369775953006],...
            'String','f(x^*)',...
            'FontWeight','bold',...
            'FontSize',12,...
            'FitBoxToText','off',...
            'EdgeColor','none');
    else
        % Find all annotations in the current figure
        annotations = findall(gcf,'Type','annotation');
        delete(annotations);
    end

    hold off
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    subplot(212)
    plot(xtest,phi,'LineWidth',2);
    hold on
    plot(xtest,PHI,'Color',"#D95319",'LineWidth',2);
    plot(xtest,stdv,'--','LineWidth',2);
    plot(xtest,EI,'LineWidth',2)
    fill([xtest(1); xtest; xtest(end)], [0; EI; 0], 'b', 'FaceAlpha', 0.3)   % shade the area under EI
    ylim([-0.8*max(EI) 1.5*max(EI)]);
    xlim([-5 5]);
    title("EI equation components plot")
    xlabel("x")
    ylabel("outputs")
    legend("$\phi$ (z(x))","$\Phi$ (z(x))","$\sigma$ (x)","EI(z(x))",'Interpreter','latex','Location','southeast');
    hold off
    if(i==1)
        pause(5)
    else
        pause(3)
    end
end
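The claim that $\mu(x)$ approaches the sine curve as observations accumulate can also be checked outside the animation. The sketch below makes one simplifying assumption: it places observations on an evenly spaced grid rather than choosing them with EI. It uses the script's kernel settings plus an assumed 1e-8 jitter and compares the largest error $|\mu(x) - \sin(x)|$ on the script's test grid for the five starting points and for fifteen evenly spaced points.

```matlab
l = 1; sig_var = 1; jitter = 1e-8;                   % kernel settings (jitter assumed)
k  = @(A, B) sig_var * exp(-(A - B').^2 / (2*l^2));  % 1-D squared exponential kernel
xs = linspace(-5, 5, 500)';                          % same test grid as the script

% Posterior mean for noise-free observations of sin at inputs X
post_mean = @(X) k(xs, X) * ((k(X, X) + jitter*eye(numel(X))) \ sin(X));

X5  = [-4; -3; -2; -1; 1];          % the script's initial points
X15 = linspace(-5, 5, 15)';         % denser, evenly spaced design (assumed)

err5  = max(abs(post_mean(X5)  - sin(xs)));
err15 = max(abs(post_mean(X15) - sin(xs)));
fprintf('max error with  5 points: %.4f\n', err5);
fprintf('max error with 15 points: %.4f\n', err15);
```

With only five points, $\mu(x)$ reverts toward the zero prior mean far from the data (for example near $x = 5$), while the denser design tracks the sine curve across the whole interval.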
