HoloML in Stan: Low-photon Image Reconstruction

Brian Ward, Bob Carpenter, and David Barmherzig

June 13, 2022

This case study is a reimplementation of the algorithm described in Barmherzig and Sun (2022) [1] as a Stan model. This requires the new features available in Stan 2.30.

Introduction

The HoloML technique is an approach to solving a specific kind of inverse problem inherent to imaging nanoscale specimens using X-ray diffraction.

To solve this problem in Stan, we first write down the forward scientific model given by Barmherzig and Sun, including the Poisson photon distribution and censored data inherent to the physical problem, and then find a solution via penalized maximum likelihood.

Experimental setup

In coherent diffraction imaging (CDI), a radiation source, typically an X-ray, is directed at a biomolecule or other specimen of interest, which causes diffraction. The resulting photon flux is measured by a far-field detector. The expected photon flux is approximately the squared magnitude of the Fourier transform of the electric field causing the diffraction. Inverting this to recover an image of the specimen is a problem usually known as phase retrieval. The phase retrieval problem is highly challenging and often lacks a unique solution [2].

Holographic coherent diffraction imaging (HCDI) is a variant in which the specimen is placed some distance away from a known reference object, and the data observed is the pattern of diffraction around both the specimen and the reference. The addition of a reference object provides additional constraints on this problem, and transforms it into a linear deconvolution problem which has a unique, closed-form solution in the idealized setting [3].

The idealized version of HCDI is formulated as

  • Given a reference $R$, data $Y = | \mathcal{F}( X + R ) | ^2$
  • Recover the source image $X$

where $\mathcal{F}$ is an oversampled Fourier transform operator.

However, the real-world setup of these experiments introduces two additional difficulties. First, data is measured from a limited number of photons: the number of photons received by each detector is modeled as Poisson distributed with expectation proportional to $Y_{ij}$ (referred to in the paper as Poisson shot noise). The expected number of photons each detector receives is denoted $N_p$. We typically have $N_p < 10$ due to the damage that radiation causes the biomolecule under observation. Second, to prevent damage to the detectors, the lowest frequencies are removed by a beamstop, which censors low-frequency observations.

The maximum likelihood estimation of the model presented here is able to recover reasonable images even under a regime featuring low photon counts and a beamstop.

Simulating Data

We simulate data from the generative model directly. This corresponds to the approach taken by Barmherzig and Sun, and is based on MATLAB code provided by Barmherzig.

Imports and helper code

Generating the data requires a few standard Python numerical libraries such as scipy and numpy. Matplotlib is also used to simplify loading in the source image and displaying results.

In [1]:
import numpy as np
from scipy import stats
import cmdstanpy

import matplotlib as mpl
import matplotlib.image as mpimg
import matplotlib.pyplot as plt

def rgb2gray(rgb):
    """Convert a nxmx3 RGB array to a grayscale nxm array.

    This function uses the same internal coefficients as MATLAB:
    https://www.mathworks.com/help/matlab/ref/rgb2gray.html
    """
    r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b

    return gray
Extra display settings set, omitted from rendered case study

Simulation parameters

To match the figures in the paper (in particular, Figure 9), we use an image of size 256x256, $N_p = 1$ (meaning each detector is expected to receive one photon), a beamstop of size 25x25 (corresponding to a radius of 13), and a separation d equal to the size of the image.

In [3]:
N = 256
d = N
N_p = 1
r = 13

M1 = 2 * N 
M2 = 2 * (2 * N + d)

We can then load the source image used for these simulations. In this model, the pixels of $X$ are grayscale values on the interval [0, 1]. The image is converted here from the standard RGBA encoding using the rgb2gray function defined above.

The following is a picture of a giant virus known as a mimivirus.

Image credit: Ghigo E, Kartenbeck J, Lien P, Pelkmans L, Capo C, Mege JL, Raoult D., CC BY 2.5, via Wikimedia Commons

In [4]:
X_src = rgb2gray(mpimg.imread('mimivirus.png'))
plt.imshow(X_src, cmap='gray', vmin=0, vmax=1)
Out[4]:
<matplotlib.image.AxesImage at 0x7efe104af4c0>

Additionally, we load in the pattern of the reference object.

The pattern used here is known as a uniformly redundant array (URA) [4]. It has been shown to be an optimal reference image for this kind of work, but other references (including none at all) could be used with the same Stan model.

The code used to generate this grid is omitted from this case study. Various options such as cappy exist to generate these patterns in Python.

In [5]:
R = np.loadtxt('URA.csv', delimiter=",", dtype=int)
plt.imshow(R, cmap='gray')
Out[5]:
<matplotlib.image.AxesImage at 0x7efe083943a0>

We create the specimen-reference hybrid image by concatenating the $X$ image, a matrix of zeros, and the reference $R$. In the true experiment, this is done by placing the specimen some distance d away from the reference, with opaque material between.

This distance is typically the same as the size of the specimen, N. One contribution of the HoloML model is allowing recovery with the reference placed closer to the specimen, and the Stan model allows for this as well.

For this simulation we use the separation of d = N.

In [7]:
X0R = np.concatenate([X_src, np.zeros((N,d)), R], axis=1)
plt.imshow(X0R, cmap='gray')
Out[7]:
<matplotlib.image.AxesImage at 0x7efe083da860>

We can simulate the diffraction pattern of photons from the X-ray by taking the absolute value squared of the 2-dimensional oversampled FFT of this hybrid object.

The oversampled FFT (denoted $\mathcal{F}$ in the paper) corresponds to padding the image in both dimensions with zeros until it is a desired size. For our case, we define the size of the padded image, M1 by M2, to be two times the size of our hybrid image, so the resulting FFT is twice oversampled. This is the oversampling ratio traditionally used for this problem; however, Barmherzig and Sun also showed that this model can operate with less oversampling.

In [8]:
Y = np.abs(np.fft.fft2(X0R, s=(M1, M2))) ** 2
plt.imshow(np.fft.fftshift(np.log1p(Y)), cmap="viridis")
Out[8]:
<matplotlib.image.AxesImage at 0x7efe06a3d7e0>
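The equivalence between the `s=` argument of `np.fft.fft2` and explicit zero-padding can be checked directly. The following is a minimal sketch on a small random array; the sizes and the names `x_small`, `m1`, and `m2` are illustrative and not part of the case study.

```python
import numpy as np

rng = np.random.default_rng(0)
x_small = rng.random((4, 6))      # stand-in for the hybrid image X0R
m1, m2 = 8, 12                    # twice the size in each dimension

# Explicit zero-padding: embed x_small in the top-left corner of zeros.
padded = np.zeros((m1, m2))
padded[:4, :6] = x_small

via_padding = np.fft.fft2(padded)
via_s_arg = np.fft.fft2(x_small, s=(m1, m2))

# Every other sample of the 2x oversampled FFT is the ordinary FFT,
# so oversampling adds values in between the original frequencies.
plain = np.fft.fft2(x_small)
print(np.allclose(via_padding, via_s_arg))
print(np.allclose(via_s_arg[::2, ::2], plain))
```

Both checks should agree, which is why passing `s=(M1, M2)` is enough to obtain the oversampled transform in the simulation.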

We simulate the photon fluxes with a Poisson pseudorandom number generator.

This code specifies a fixed seed to ensure the same fake data is generated each time.

In [9]:
rate = N_p / Y.mean()
Y_tilde = stats.poisson.rvs(rate * Y, random_state=1234)
plt.imshow(np.fft.fftshift(np.log1p(Y_tilde)), cmap="viridis")
Out[9]:
<matplotlib.image.AxesImage at 0x7efe06ab4a30>
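Scaling by `N_p / Y.mean()` makes the average Poisson rate across detectors equal to $N_p$. A small sketch of that normalization, using made-up intensities (`Y_demo` and `N_p_demo` are illustrative stand-ins, not the simulated diffraction pattern):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
Y_demo = rng.random((64, 64)) * 100.0   # hypothetical intensities
N_p_demo = 1.0

rate_demo = N_p_demo / Y_demo.mean()
lam_demo = rate_demo * Y_demo           # per-detector Poisson rates
counts = stats.poisson.rvs(lam_demo, random_state=1234)

print(lam_demo.mean())   # equals N_p_demo up to floating point
print(counts.mean())     # close to N_p_demo, but random
```

Individual detectors still receive very different rates; only the average is fixed at $N_p$.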

Finally, we need to remove the low frequency content of the data. This is caused in the physical experiment by the inclusion of a beamstop, which protects the instrument used by preventing the strongest parts of the beam from directly shining on the detectors.

The beamstop is represented by $\mathcal{B}$, a matrix of 0s and 1s. Zeros indicate that the data is occluded, while ones represent transparent portions.

In [10]:
B_cal = np.ones((M1,M2), dtype=int)
B_cal[M1 // 2 - r + 1: M1 // 2 + r, M2 // 2 - r + 1: M2 // 2 + r] = 0
B_cal = np.fft.ifftshift(B_cal)
# Sanity check
assert (M1 * M2 - B_cal.sum()) == (( 2 * r - 1)**2)
plt.imshow(np.fft.fftshift(B_cal), cmap="gray", vmin=0, vmax=1.25)
Out[10]:
<matplotlib.image.AxesImage at 0x7efe069240a0>

We use this matrix $\mathcal{B}$ to mask the low frequencies of the simulated data. After removing these elements from the simulated data, we have the final input which is used in our model.

In [11]:
Y_tilde *= B_cal
plt.imshow(np.fft.fftshift(np.log1p(Y_tilde)), cmap="viridis")
Out[11]:
<matplotlib.image.AxesImage at 0x7efe069624d0>

Stan Model

The Stan model code is a direct translation of the log density of the forward model described in the paper [1] and above. The full model can be seen in the appendix.

Functions

We define two helper functions to implement this model in Stan. The first is a function responsible for generating the $\mathcal{B}$ matrix. Because Stan currently does not have FFT shifting functions, this is done by manually assigning to the corners of the matrix.

functions {
  matrix beamstop_gen(int M1, int M2, int r) {
    matrix[M1, M2] B_cal = rep_matrix(1, M1, M2);

    // upper left
    B_cal[1 : r, 1 : r] = rep_matrix(0, r, r);
    // upper right
    B_cal[1 : r, M2 - r + 2 : M2] = rep_matrix(0, r, r - 1);
    // lower left
    B_cal[M1 - r + 2 : M1, 1 : r] = rep_matrix(0, r - 1, r);
    // lower right
    B_cal[M1 - r + 2 : M1, M2 - r + 2 : M2] = rep_matrix(0, r - 1, r - 1);
    return B_cal;
  }

The FFT described in the paper is an oversampled FFT. This corresponds to embedding the image in a larger array of zeros and results in a sort of interpolation between frequencies in the result.

We write an overload of the fft2 function which implements this behavior, similar to the signatures found in MATLAB or Python libraries.

complex_matrix fft2(complex_matrix Z, int N, int M) {
    int r = rows(Z);
    int c = cols(Z); 
    complex_matrix[N, M] pad = rep_matrix(0, N, M);
    pad[1 : r, 1 : c] = Z;

    return fft2(pad);
  }
} // end functions block

Note that while the first input of this function is a complex_matrix, it will also accept real matrices due to the built-in type promotion in Stan.
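As a cross-check on the corner assignments in beamstop_gen, the same mask can be built in NumPy both ways and compared. This is a sketch, not part of the original notebook: `beamstop_centered` repeats the construction used in the simulation, and `beamstop_corners` is a hypothetical 0-based port of the Stan function.

```python
import numpy as np

def beamstop_centered(M1, M2, r):
    """Mask built as in the simulation: centered square, then ifftshift."""
    B = np.ones((M1, M2), dtype=int)
    B[M1 // 2 - r + 1 : M1 // 2 + r, M2 // 2 - r + 1 : M2 // 2 + r] = 0
    return np.fft.ifftshift(B)

def beamstop_corners(M1, M2, r):
    """NumPy port of the Stan beamstop_gen (0-based indexing)."""
    B = np.ones((M1, M2), dtype=int)
    if r == 0:
        return B
    B[:r, :r] = 0                      # upper left:  r x r
    B[:r, M2 - r + 1 :] = 0            # upper right: r x (r - 1)
    B[M1 - r + 1 :, :r] = 0            # lower left:  (r - 1) x r
    B[M1 - r + 1 :, M2 - r + 1 :] = 0  # lower right: (r - 1) x (r - 1)
    return B

# Sizes used in this case study: M1 = 512, M2 = 1536, r = 13.
same = np.array_equal(beamstop_centered(512, 1536, 13),
                      beamstop_corners(512, 1536, 13))
print(same)
```

The asymmetric block sizes arise because the zero-frequency component sits at index 0 after the shift, so the occluded square of width 2r - 1 splits into r entries on one side and r - 1 on the other.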

Model inputs

The Stan model needs the same information the generative model did, except it is supplied with $\tilde{Y}$ instead of the source image $X$, plus a scale parameter for the prior, $\sigma$. Smaller values of $\sigma$ (approaching 0) lead to increasing amounts of blur in the resulting image.

data {
  int<lower=0> N;                    // image dimension
  matrix<lower=0, upper=1>[N, N] R;  // reference image
  int<lower=0, upper=N> d;           // separation between sample and reference image
  int<lower=N> M1;                   // rows of padded matrices
  int<lower=2 * N + d> M2;           // cols of padded matrices
  int<lower=0, upper=M1> r;          // beamstop radius. replaces omega1, omega2 in paper

  real<lower=0> N_p;                  // avg number of photons per pixel
  array[M1, M2] int<lower=0> Y_tilde; // observed number of photons

  real<lower=0> sigma;                // standard deviation of pixel prior.
}

The constraints listed above, such as lower=0, perform input validation. For example, the size of the padded FFT is, at a minimum, the size of the hybrid $X0R$ specimen, and we are able to encode this in the model with the lower bounds on M1 and M2.
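The same constraints can be checked on the Python side before the data is handed to Stan, which can give an earlier and more readable error. The helper below is a hypothetical sketch (the name `check_holoml_data` and the toy inputs are not part of the case study); it only mirrors the bounds declared in the data block.

```python
import numpy as np

def check_holoml_data(data):
    """Raise ValueError if `data` violates the Stan data-block constraints."""
    N, d, M1, M2, r = (data[k] for k in ("N", "d", "M1", "M2", "r"))
    R = np.asarray(data["R"])
    Y_tilde = np.asarray(data["Y_tilde"])
    checks = {
        "N >= 0": N >= 0,
        "R is N x N with values in [0, 1]":
            R.shape == (N, N) and bool(np.all((R >= 0) & (R <= 1))),
        "0 <= d <= N": 0 <= d <= N,
        "M1 >= N": M1 >= N,
        "M2 >= 2 * N + d": M2 >= 2 * N + d,
        "0 <= r <= M1": 0 <= r <= M1,
        "N_p >= 0": data["N_p"] >= 0,
        "Y_tilde is M1 x M2 non-negative integers":
            Y_tilde.shape == (M1, M2)
            and np.issubdtype(Y_tilde.dtype, np.integer)
            and bool(np.all(Y_tilde >= 0)),
        "sigma >= 0": data["sigma"] >= 0,
    }
    failed = [name for name, ok in checks.items() if not ok]
    if failed:
        raise ValueError("invalid HoloML data: " + "; ".join(failed))

# Toy inputs for illustration only.
toy = {"N": 2, "R": np.eye(2, dtype=int), "d": 2, "M1": 4, "M2": 12,
       "r": 1, "N_p": 1.0, "Y_tilde": np.zeros((4, 12), dtype=int),
       "sigma": 1.0}
check_holoml_data(toy)  # passes silently
```

Stan performs these checks itself when the data is read, so a helper like this is a convenience rather than a requirement.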

Additional fixed information

Stan provides the ability to compute transformed data, values which depend on the inputs but only need to be evaluated once per model. This allows us to construct and store $\mathcal{B}$ once, without recomputing it each iteration or requiring it as input.

transformed data {
  matrix[M1, M2] B_cal = beamstop_gen(M1, M2, r);
  matrix[N, d] separation = rep_matrix(0, N, d);
}

Parameters

This model has only one parameter, the image $X$. It is constrained to grayscale values between 0 and 1.

parameters {
  matrix<lower=0, upper=1>[N, N] X;
}

Model code

Priors

We add a prior on $X$ to impose an L2 penalty on adjacent pixels. This induces a Gaussian blur on the result, and it is not strictly necessary for running the model.

This prior is coded in our Stan program by looping over the rows and columns and using a vectorized call to the normal distribution. This results in each interior pixel being adjacent to 4 others. One could also formulate a prior which includes diagonally adjacent pixels.

model {
  for (i in 1 : rows(X) - 1) {
    X[i] ~ normal(X[i + 1], sigma);
  }
  for (j in 1 : cols(X) - 1) {
    X[ : , j] ~ normal(X[ : , j + 1], sigma);
  }
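For readers more familiar with NumPy, the two loops above correspond to summing normal log densities over vertically and horizontally adjacent pixel pairs. The following is a sketch of that quantity; the function name `smoothness_log_prior` and the test images are illustrative, and the value can differ from Stan's internal log density by an additive constant because `~` statements may drop constant terms.

```python
import numpy as np
from scipy import stats

def smoothness_log_prior(X, sigma):
    """Sum of normal log densities over adjacent pixel pairs."""
    # Each row compared with the row below it.
    rows = stats.norm.logpdf(X[:-1, :], loc=X[1:, :], scale=sigma).sum()
    # Each column compared with the column to its right.
    cols = stats.norm.logpdf(X[:, :-1], loc=X[:, 1:], scale=sigma).sum()
    return rows + cols

rng = np.random.default_rng(0)
flat = np.full((8, 8), 0.5)     # perfectly smooth image
noisy = rng.random((8, 8))      # rough image

# A smooth image is favored over a rough one.
print(smoothness_log_prior(flat, 0.05) > smoothness_log_prior(noisy, 0.05))
```

Up to a constant, this is $-\frac{1}{2\sigma^2}$ times the sum of squared differences between neighbors, which is why smaller $\sigma$ penalizes rough images more strongly.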

Likelihood

The model likelihood encodes the forward model. We construct the hybrid specimen, compute $|\mathcal{F}(X0R)|^2$, and then compute the rate $\lambda$ by scaling by the average number of photons $N_p$.

We then loop over this result. If the current indices are not occluded by the beamstop $\mathcal{B}$, we say that the data $\tilde{Y}$ follows a Poisson distribution with $\lambda$ as the rate parameter.

  // object representing specimen and reference together
  matrix[N, 2 * N + d] X0R = append_col(X, append_col(separation, R));
  // signal - squared magnitude of the (oversampled) FFT
  matrix[M1, M2] Y = abs(fft2(X0R, M1, M2)) .^ 2;

  real N_p_over_Y_bar = N_p / mean(Y);
  matrix[M1, M2] lambda = N_p_over_Y_bar * Y;

  for (m1 in 1 : M1) {
    for (m2 in 1 : M2) {
      if (B_cal[m1, m2]) {
        Y_tilde[m1, m2] ~ poisson(lambda[m1, m2]);
      }
    }
  }
} // end model block
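The likelihood can also be written in NumPy, which may help when checking the Stan program against the simulation code. This is a sketch under the same definitions as above; `holoml_log_lik` and the toy inputs are hypothetical names introduced here, and the value includes the Poisson normalizing terms that Stan may drop.

```python
import numpy as np
from scipy import stats

def holoml_log_lik(X, R, d, M1, M2, N_p, Y_tilde, B_cal):
    """Poisson log-likelihood of Y_tilde over entries not hidden by the beamstop."""
    N = X.shape[0]
    X0R = np.concatenate([X, np.zeros((N, d)), R], axis=1)
    Y = np.abs(np.fft.fft2(X0R, s=(M1, M2))) ** 2
    lam = (N_p / Y.mean()) * Y
    keep = B_cal.astype(bool)            # True where data is observed
    return stats.poisson.logpmf(Y_tilde[keep], lam[keep]).sum()

# Toy problem for illustration only.
rng = np.random.default_rng(0)
N_demo, d_demo, M1_demo, M2_demo = 4, 4, 8, 24
X_demo = rng.random((N_demo, N_demo))
R_demo = rng.integers(0, 2, size=(N_demo, N_demo))
B_demo = np.ones((M1_demo, M2_demo), dtype=int)
B_demo[0, 0] = 0                         # toy beamstop: hide the DC term
Y_demo = rng.poisson(5.0, size=(M1_demo, M2_demo))

ll = holoml_log_lik(X_demo, R_demo, d_demo, M1_demo, M2_demo,
                    1.0, Y_demo, B_demo)
print(np.isfinite(ll))
```

Because occluded entries are skipped entirely, whatever value is stored in `Y_tilde` under the beamstop has no influence on the result.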

Optimization

Now that we have our simulated data and our generative model, we solve the inverse problem.

Data preparation

We prepare a dictionary of data corresponding to the model's data block. This is mostly reusing constants defined earlier for the data simulation.

In [12]:
sigma = 1 # prior smoothing
data = {
    "N": N,
    "R": R,
    "d": d,
    "M1": M1,
    "M2": M2,
    "Y_tilde": Y_tilde,
    "r": r,
    "N_p": N_p,
    "sigma": sigma
}

To run the model from Python, we instantiate it as a CmdStanModel object from cmdstanpy.

In [13]:
HoloML_model = cmdstanpy.CmdStanModel(stan_file="./holoml.stan")
INFO:cmdstanpy:found newer exe file, not recompiling

Here we use optimization via the limited-memory quasi-Newton L-BFGS algorithm. This method has a bit more curvature information than what is available to the conjugate gradient approach, but less than the second order trust-region method used in the paper. This should take a few (1-3) minutes, depending on the machine you are running on.

It is also possible to sample the model using the No-U-Turn Sampler (NUTS), but evaluations of this are out of the scope of this case study.

In [14]:
%time fit = HoloML_model.optimize(data, inits=1, seed=5678)
INFO:cmdstanpy:Chain [1] start processing
INFO:cmdstanpy:Chain [1] done processing
CPU times: user 241 ms, sys: 41.8 ms, total: 283 ms
Wall time: 2min 39s

We use the method stan_variable to extract the penalized maximum likelihood estimate from the fit object returned by optimization.

We can use this to plot the recovered image alongside the original.

In [15]:
fig = plt.figure()

ax1 = fig.add_subplot(1, 4, 1, title="Source Image")
ax1.imshow(X_src, cmap="gray", vmin=0, vmax=1)

ax2 = fig.add_subplot(1, 4, 2, title="Recovered Image")
ax2.imshow(fit.stan_variable("X"), cmap="gray", vmin=0, vmax=1)
Out[15]:
<matplotlib.image.AxesImage at 0x7efe05c3f880>
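Alongside the visual comparison, a single number can summarize how close the recovered image is to the source. The relative Frobenius-norm error below is one simple choice, offered as an assumption rather than the metric used in the paper; `relative_error` and the stand-in arrays are hypothetical.

```python
import numpy as np

def relative_error(X_true, X_hat):
    """Frobenius-norm error of X_hat relative to the norm of X_true."""
    X_true = np.asarray(X_true, dtype=float)
    X_hat = np.asarray(X_hat, dtype=float)
    return np.linalg.norm(X_hat - X_true) / np.linalg.norm(X_true)

# Stand-in arrays; in the notebook these would be
# X_src and fit.stan_variable("X").
truth = np.array([[0.0, 0.5], [0.5, 1.0]])
estimate = truth + 0.1
print(relative_error(truth, estimate))
```

A value of 0 means a perfect reconstruction, and larger values indicate a worse match.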

Varying $N_p$

The above selection of $N_p=1$ is a reasonable choice for a real experiment, but both smaller and larger numbers of expected photons may be used. The following are results for two other levels, $N_p = 0.1$ and $N_p = 10$.

This requires repeating the final few steps of the data generation and then re-fitting the model accordingly.

In [16]:
N_p = 0.1

Y_tilde = stats.poisson.rvs((N_p / Y.mean()) * Y, random_state=1234) * B_cal

data_fewer_photons = data.copy()
data_fewer_photons['N_p'] = N_p
data_fewer_photons['Y_tilde'] = Y_tilde

%time fit_fewer_photons = HoloML_model.optimize(data_fewer_photons, inits=1, seed=5678)
INFO:cmdstanpy:Chain [1] start processing
INFO:cmdstanpy:Chain [1] done processing
CPU times: user 328 ms, sys: 47.9 ms, total: 376 ms
Wall time: 3min 5s
In [17]:
N_p = 10

Y_tilde = stats.poisson.rvs((N_p / Y.mean()) * Y, random_state=1234) * B_cal

data_more_photons = data.copy()
data_more_photons['N_p'] = N_p
data_more_photons['Y_tilde'] = Y_tilde

%time fit_more_photons = HoloML_model.optimize(data_more_photons, inits=1, seed=5678)
INFO:cmdstanpy:Chain [1] start processing
INFO:cmdstanpy:Chain [1] done processing
CPU times: user 249 ms, sys: 24.5 ms, total: 274 ms
Wall time: 1min 27s

It is worth noting that these two optimizations take noticeably different amounts of time compared to the original, as the differing amounts of data yield posteriors which are closer to or further from normal.

In addition to the difference in runtime, the resulting images are very different.

In [18]:
fig = plt.figure()

ax1 = fig.add_subplot(1, 4, 1, title="Source Image")
ax1.imshow(X_src, cmap="gray", vmin=0, vmax=1)

ax2 = fig.add_subplot(1, 4, 2, title="Recovered Image\n($N_p=10$)")
ax2.imshow(fit_more_photons.stan_variable("X"), cmap="gray", vmin=0, vmax=1)

ax3 = fig.add_subplot(1, 4, 3, title="Recovered Image\n($N_p=1$)")
ax3.imshow(fit.stan_variable("X"), cmap="gray", vmin=0, vmax=1)

ax4 = fig.add_subplot(1, 4, 4, title="Recovered Image\n($N_p=0.1$)")
ax4.imshow(fit_fewer_photons.stan_variable("X"), cmap="gray", vmin=0, vmax=1)
Out[18]:
<matplotlib.image.AxesImage at 0x7efe047f5a80>
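The data-generation steps repeated in this section for each photon level could be wrapped in small helpers. The following is a sketch only: `simulate_counts` and `with_photon_level` are hypothetical names, and the toy arrays stand in for the `Y`, `B_cal`, and `data` objects defined earlier.

```python
import numpy as np
from scipy import stats

def simulate_counts(Y, N_p, B_cal, seed=1234):
    """Draw beamstop-masked Poisson photon counts with mean rate N_p."""
    lam = (N_p / Y.mean()) * Y
    return stats.poisson.rvs(lam, random_state=seed) * B_cal

def with_photon_level(data, Y, N_p, B_cal, seed=1234):
    """Copy `data`, replacing N_p and Y_tilde for a new photon level."""
    new_data = dict(data)
    new_data["N_p"] = N_p
    new_data["Y_tilde"] = simulate_counts(Y, N_p, B_cal, seed)
    return new_data

# Toy inputs for illustration only.
rng = np.random.default_rng(0)
Y_toy = rng.random((8, 12)) * 50.0
B_toy = np.ones((8, 12), dtype=int)
B_toy[:2, :2] = 0
base = {"N_p": 1, "Y_tilde": simulate_counts(Y_toy, 1, B_toy), "sigma": 1}
fewer = with_photon_level(base, Y_toy, 0.1, B_toy)
print(fewer["N_p"], fewer["Y_tilde"].sum())
```

The original dictionary is left untouched, so the baseline fit and the variants can be compared side by side.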

Prior tuning

The above choice of $\sigma = 1$ has a very slight effect on the output image.

We also show the recovered image for $\sigma = 20$, which provides even less smoothing than the above, and for $\sigma = 0.05$. This smaller value imposes a greater penalty on adjacent pixels which are significantly different from each other, smoothing out the result.

Each of these is done with the original value of $N_p = 1$.

In [19]:
data_weaker_prior = data.copy()
data_weaker_prior['sigma'] = 20

%time fit_rougher = HoloML_model.optimize(data_weaker_prior, inits=1, seed=5678)
INFO:cmdstanpy:Chain [1] start processing
INFO:cmdstanpy:Chain [1] done processing
CPU times: user 243 ms, sys: 33.3 ms, total: 277 ms
Wall time: 3min 15s
In [20]:
data_stronger_prior = data.copy()
data_stronger_prior['sigma'] = 0.05

%time fit_smooth = HoloML_model.optimize(data_stronger_prior, inits=1, seed=5678)
INFO:cmdstanpy:Chain [1] start processing
INFO:cmdstanpy:Chain [1] done processing
CPU times: user 237 ms, sys: 38.6 ms, total: 276 ms
Wall time: 2min 32s
In [21]:
fig = plt.figure()

ax1 = fig.add_subplot(1, 4, 1, title="Source Image")
ax1.imshow(X_src, cmap="gray", vmin=0, vmax=1)

ax2 = fig.add_subplot(1, 4, 2, title="Recovered Image\n($\\sigma=0.05$)")
ax2.imshow(fit_smooth.stan_variable("X"), cmap="gray", vmin=0, vmax=1)

ax3 = fig.add_subplot(1, 4, 3, title="Recovered Image\n($\\sigma=1$)")
ax3.imshow(fit.stan_variable("X"), cmap="gray", vmin=0, vmax=1)

ax4 = fig.add_subplot(1, 4, 4, title="Recovered Image\n($\\sigma=20$)")
ax4.imshow(fit_rougher.stan_variable("X"), cmap="gray", vmin=0, vmax=1)
Out[21]:
<matplotlib.image.AxesImage at 0x7efdfef2ce20>

References

[1] Barmherzig, D. A., & Sun, J. (2022). Towards practical holographic coherent diffraction imaging via maximum likelihood estimation. Opt. Express, 30(5), 6886–6906. doi:10.1364/OE.445015

[2] Barnett, A. H., Epstein, C. L., Greengard, L. F., & Magland, J. F. (2020). Geometry of the phase retrieval problem. Inverse Problems, 36(9), 094003. doi:10.1088/1361-6420/aba5ed

[3] Barmherzig, D. A., Sun, J., Li, P.-N., Lane, T. J., & Candès, E. J. (2019). Holographic phase retrieval and reference design. Inverse Problems, 35(9), 094001. doi:10.1088/1361-6420/ab23d1

[4] Fenimore, E. E., & Cannon, T. M. (1978). Coded aperture imaging with uniformly redundant arrays. Appl. Opt., 17(3), 337–347. doi:10.1364/AO.17.000337

Appendix: Full Stan Code

Out[22]:
functions {
  /**
   * Return M1 x M2 matrix of 1 values with blocks in corners set to
   * 0, where the upper left is (r x r), the upper right is (r x r-1),
   * the lower left is (r-1 x r), and the lower right is (r-1 x r-1).
   * This corresponds to zeroing out the lowest-frequency portions of
   * an FFT.
   * @param M1 number of rows
   * @param M2 number of cols
   * @param r block dimension
   * @return matrix of 1 values with 0-padded corners
   */
  matrix beamstop_gen(int M1, int M2, int r) {
    matrix[M1, M2] B_cal = rep_matrix(1, M1, M2);
    if (r == 0) {
      return B_cal;
    }
    // upper left
    B_cal[1 : r, 1 : r] = rep_matrix(0, r, r);
    // upper right
    B_cal[1 : r, M2 - r + 2 : M2] = rep_matrix(0, r, r - 1);
    // lower left
    B_cal[M1 - r + 2 : M1, 1 : r] = rep_matrix(0, r - 1, r);
    // lower right
    B_cal[M1 - r + 2 : M1, M2 - r + 2 : M2] = rep_matrix(0, r - 1, r - 1);
    return B_cal;
  }
  
  /**
   * Return the matrix corresponding to the fast Fourier
   * transform of Z after it is padded with zeros to size
   * N by M
   * When N by M is larger than the dimensions of Z,
   * this computes an oversampled FFT.
   *
   * @param Z matrix of values
   * @param N number of rows desired (must be >= rows(Z))
   * @param M number of columns desired (must be >= cols(Z))
   * @return the FFT of Z padded with zeros
   */
  complex_matrix fft2(complex_matrix Z, int N, int M) {
    int r = rows(Z);
    int c = cols(Z);
    if (r > N) {
      reject("N must be at least rows(Z)");
    }
    if (c > M) {
      reject("M must be at least cols(Z)");
    }
    
    complex_matrix[N, M] pad = rep_matrix(0, N, M);
    pad[1 : r, 1 : c] = Z;
    
    return fft2(pad);
  }
}
data {
  int<lower=0> N; // image dimension
  matrix<lower=0, upper=1>[N, N] R; // reference image
  int<lower=0, upper=N> d; // separation between sample and reference image
  int<lower=N> M1; // rows of padded matrices
  int<lower=2 * N + d> M2; // cols of padded matrices
  int<lower=0, upper=M1> r; // beamstop radius. replaces omega1, omega2 in paper
  
  real<lower=0> N_p; // avg number of photons per pixel
  array[M1, M2] int<lower=0> Y_tilde; // observed number of photons
  
  real<lower=0> sigma; // standard deviation of pixel prior.
}
transformed data {
  matrix[M1, M2] B_cal = beamstop_gen(M1, M2, r);
  matrix[N, d] separation = rep_matrix(0, N, d);
}
parameters {
  matrix<lower=0, upper=1>[N, N] X;
}
model {
  // prior - penalizing L2 on adjacent pixels
  for (i in 1 : rows(X) - 1) {
    X[i] ~ normal(X[i + 1], sigma);
  }
  for (j in 1 : cols(X) - 1) {
    X[ : , j] ~ normal(X[ : , j + 1], sigma);
  }
  
  // likelihood
  // object representing specimen and reference together
  matrix[N, 2 * N + d] X0R = append_col(X, append_col(separation, R));
  // signal - squared magnitude of the (oversampled) FFT
  matrix[M1, M2] Y = abs(fft2(X0R, M1, M2)) .^ 2;
  
  real N_p_over_Y_bar = N_p / mean(Y);
  matrix[M1, M2] lambda = N_p_over_Y_bar * Y;
  
  for (m1 in 1 : M1) {
    for (m2 in 1 : M2) {
      if (B_cal[m1, m2]) {
        Y_tilde[m1, m2] ~ poisson(lambda[m1, m2]);
      }
    }
  }
}

Digression: Efficiency

The model above is coded for readability and sticks closely to the mathematical formulation of the process. However, this does lead to an inefficient conditional inside the tightest loop of the model to handle the beamstop occlusion.

In practice, it is possible to avoid this conditional by changing how the data is stored. Instead of storing the beamstop occlusion as a parallel matrix, we can pre-compute the list of indices which are included once and store it. Then, we can create flat representations of both the data $\tilde{Y}$ and the rate $\lambda$, allowing us to use a vectorized version of the Poisson distribution.

transformed data {
  // assumes beamstop_gen is modified to return array[,] int rather than matrix
  array[M1, M2] int B_cal = beamstop_gen(M1, M2, r);
  int total = sum(to_array_1d(B_cal));
  array[total, 2] int idxs;
  // pre-compute indices
  int current = 1;
  for (n in 1:M1){
    for (m in 1:M2){
      if (B_cal[n, m]){
        idxs[current, :] = {n,m};
        current += 1;
      }
    }
  }
  // flatten data accordingly
  array[total] int<lower=0> Ys;
  for (n in 1:total) {
    Ys[n] = Y_tilde[idxs[n, 1], idxs[n, 2]];
  }
}
model {
  // ... same code for computing matrix[M1, M2] lambda here
  array[total] real lambdas;
  for (n in 1:total) {
    lambdas[n] = lambda[idxs[n, 1], idxs[n, 2]];  // much cheaper than branching
  }

  Ys ~ poisson(lambdas);  // fully vectorized
}

This formulation of the model reduces the amount of time per gradient evaluation by 15-20%. A brief evaluation suggests, however, that the impact on optimization runtime is minimal.
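The same idea can be illustrated in NumPy: find the unoccluded indices once, then evaluate the Poisson log density on flat arrays instead of branching inside a double loop. This is a toy sketch with made-up sizes and rates (`B_demo`, `lam_demo`, and `Y_demo` are illustrative), intended only to show that the two formulations give the same value.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
M1_demo, M2_demo = 6, 8
B_demo = np.ones((M1_demo, M2_demo), dtype=int)
B_demo[:2, :2] = 0                      # toy occluded block
lam_demo = rng.random((M1_demo, M2_demo)) + 0.5
Y_demo = stats.poisson.rvs(lam_demo, random_state=1)

# Branching version: test the mask inside the double loop.
ll_loop = 0.0
for m1 in range(M1_demo):
    for m2 in range(M2_demo):
        if B_demo[m1, m2]:
            ll_loop += stats.poisson.logpmf(Y_demo[m1, m2], lam_demo[m1, m2])

# Precomputed version: find the kept indices once, then vectorize.
rows, cols = np.nonzero(B_demo)
Ys = Y_demo[rows, cols]
ll_vec = stats.poisson.logpmf(Ys, lam_demo[rows, cols]).sum()
print(np.isclose(ll_loop, ll_vec))
```

In the Stan version, the index list and the flattened data live in transformed data, so only the gathering of the rates is repeated at each gradient evaluation.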

Reproducibility

This notebook's source and related materials are available at https://github.com/WardBrian/holoml-in-stan.

The following versions were used to produce this page:

In [23]:
%load_ext watermark
%watermark -n -u -v -iv -w
print("CmdStan:", cmdstanpy.utils.cmdstan_version())
Last updated: Fri Jul 01 2022

Python implementation: CPython
Python version       : 3.10.4
IPython version      : 8.4.0

scipy     : 1.8.0
cmdstanpy : 1.0.1
numpy     : 1.22.3
IPython   : 8.4.0
matplotlib: 3.5.2

Watermark: 2.3.1

CmdStan: (2, 30)

The rendered HTML output is produced with

jupyter nbconvert --to html "HoloML in Stan.ipynb" --template classic --TagRemovePreprocessor.remove_input_tags=hide-code --CSSHTMLHeaderPreprocessor.style=tango --execute