Chapter 8 BART

8.1 A BART version of our hierarchical trees model

Let’s define:

  • P to be the number of trees
  • J to be the total number of groups
  • \(\Theta\) to be the set of terminal-node parameters
    • \(\mu\) and \(\mu_j\) for each tree \(p = 1, \dots, P\)

We have a variable of interest for which we assume:

\[\begin{equation} y_{ij} = \sum_{p = 1}^{P} \overbrace{\mathbb{G}}^\text{Tree look up function}(\underbrace{X_{ij}}_\text{Covariates}, \overbrace{T_{p}}^\text{Tree structure}, \overbrace{\Theta_{p}}^\text{Terminal node parameters}) + \underbrace{\epsilon_{ij}}_\text{Noise} \end{equation}\]

for observation \(i = 1, \dots, n_j\) in group \(j = 1, \dots, J\). We also have that:

\[\begin{equation} \epsilon_{ij} \sim N(0, \tau^{-1}), \end{equation}\]

where \(\tau^{-1}\) is the residual variance (so \(\tau\) is the residual precision). In this setting, \(\Theta_{p}\) represents the terminal-node parameters plus the individual group parameters for tree \(p\).
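As a concrete sketch, the sum-of-trees likelihood above can be simulated as follows. The stump construction, dimensions, and parameter values are all illustrative assumptions, not part of the model specification:

```python
import numpy as np

rng = np.random.default_rng(42)

# Toy dimensions (illustrative values, not from the text)
P = 3        # number of trees
J = 4        # number of groups
n_j = 50     # observations per group
N = J * n_j

X = rng.normal(size=(N, 2))            # covariates X_ij
groups = np.repeat(np.arange(J), n_j)  # group label for each observation

# Stand-in for the tree look-up function G(X, T_p, Theta_p): a one-split
# "stump" that returns a terminal-node value for every observation.
def make_stump(split_var, cutpoint, left_val, right_val):
    def g(X):
        return np.where(X[:, split_var] < cutpoint, left_val, right_val)
    return g

trees = [make_stump(0, 0.0, -0.5, 0.5),
         make_stump(1, 0.3, 0.2, -0.2),
         make_stump(0, 1.0, 0.1, -0.1)]

tau = 4.0                                # residual precision
eps = rng.normal(0.0, tau ** -0.5, N)    # eps_ij ~ N(0, tau^{-1})

# y_ij = sum_p G(X_ij, T_p, Theta_p) + eps_ij
y = sum(g(X) for g in trees) + eps
```

Real BART trees adapt their splits to the data; a fixed stump is used here only so the sum-of-trees structure of the likelihood is visible in a few lines.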

For a single terminal node, let:

\[\begin{equation} R_{ijp1} = Y_{ij}^{(1)} - \sum_{t \neq p} \mathbb{G}(X_{ij}^{(1)}, T_{t}, \Theta_{t}) \end{equation}\]

which represents the partial residual for observation \(i\) in group \(j\), for tree \(p\), in terminal node 1. Now, let
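The partial-residual computation can be sketched as below; `tree_preds[p]` stands in for \(\mathbb{G}(X, T_p, \Theta_p)\) evaluated at every observation, and all values are toy assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy fitted values for each tree at each observation (assumptions)
P, N = 3, 10
y = rng.normal(size=N)
tree_preds = rng.normal(scale=0.1, size=(P, N))

def partial_residual(y, tree_preds, p):
    """R_p = y - sum_{t != p} G(X, T_t, Theta_t): subtract every tree's fit
    except tree p's, so tree p can be updated against what is left over."""
    total = tree_preds.sum(axis=0)
    return y - (total - tree_preds[p])

R0 = partial_residual(y, tree_preds, 0)
```

Subtracting tree \(p\) from the running total (rather than re-summing the other \(P-1\) trees) is the usual way this is done inside a backfitting loop, since the full fit is already available.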

\[\begin{equation} \underset{\sim}{R_j} = \{R_{ij};\ i = 1, \dots, n_j\}, \qquad j = 1, \dots, J, \end{equation}\]

then

\[\begin{equation} \begin{aligned} \underset{\sim}{R_j} &\sim N(\mu_j, \tau^{-1}), \\ \mu_{jpl} &\sim N(\mu_{pl}, k_1\tau^{-1}/P), \\ \mu_{pl} &\sim N(0, k_2 \tau^{-1}/P), \end{aligned} \end{equation}\]

where \(P\) is the number of trees, \(p\) the tree index, \(j\) the group index, and \(l\) the terminal node index; \(\mu_j\) and \(\mu\) are used as shorthand for \(\mu_{jpl}\) and \(\mu_{pl}\) in the terminal node under consideration,

with \(l = 1, \dots, n_{p}\), where \(n_{p}\) is the number of terminal nodes in tree \(p\), and \(N_p = \sum_{p = 1}^{P} n_p\) is the total number of terminal nodes across all trees.
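A draw from these priors can be sketched as follows. The dimensions and the values of \(k_1\), \(k_2\), and \(\tau\) are assumptions for illustration, and every tree is given the same number of terminal nodes so the arrays stay rectangular:

```python
import numpy as np

rng = np.random.default_rng(1)

# Illustrative hyperparameter values (assumptions, not from the text)
P, J, n_leaves = 5, 3, 2     # trees, groups, terminal nodes per tree
k1, k2, tau = 1.0, 1.0, 2.0

# mu_pl ~ N(0, k2 tau^{-1} / P): overall terminal-node means
mu_pl = rng.normal(0.0, np.sqrt(k2 / (tau * P)), size=(P, n_leaves))

# mu_jpl ~ N(mu_pl, k1 tau^{-1} / P): group-level means centred on the
# corresponding node mean; numpy broadcasts mu_pl over the group axis j.
mu_jpl = rng.normal(mu_pl, np.sqrt(k1 / (tau * P)), size=(J, P, n_leaves))
```

The \(1/P\) scaling in both variances is what keeps the prior on the *sum* of trees stable as \(P\) grows, mirroring the usual BART shrinkage argument.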

Using the same marginalisation as for a single tree:

\[\begin{equation} \underset{\sim}{R_j} \sim MVN(\mu \mathbf{1}, \tau^{-1} ((k_1/P) MM^{T} + \mathbb{I})), \end{equation}\]

where \(M\) is the group model matrix. Using the same trick as before, with \(\Psi = (k_1/P) MM^{T} + \mathbb{I}\), marginalising over \(\mu\) gives

\[\begin{equation} \underset{\sim}{R_j} \sim MVN(\mathbf{0}, \tau^{-1} (\Psi + (k_2/P) \mathbf{1}\mathbf{1}^{T})), \end{equation}\]

which is used to get the marginal likelihood of a proposed new tree. The new posterior updates will be:

\[\begin{equation} \mu | \dots \sim N\left( \frac{\mathbf{1}^{T} \Psi^{-1} R }{\mathbf{1}^{T} \Psi^{-1} \mathbf{1} + (k_2/P)^{-1}},\ \tau^{-1} (\mathbf{1}^{T} \Psi^{-1} \mathbf{1} + (k_2/P)^{-1})^{-1}\right), \end{equation}\]

\[\begin{equation} \mu_j | \dots \sim N\left( \frac{P \mu /k_1 + n_j \bar R_j}{n_j + P/k_1},\ \tau^{-1} (n_j + P/k_1)^{-1}\right), \end{equation}\]

with the mean and variance taken separately for each group \(j = 1, \dots, J\).
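These two conjugate updates can be checked numerically. The residuals, group sizes, and the values of \(k_1\), \(k_2\), and \(\tau\) below are toy assumptions, and \(\Psi\) is built with the \(1/P\) scaling implied by the \(\mu_{jpl}\) prior:

```python
import numpy as np

rng = np.random.default_rng(2)

# Toy inputs (assumptions for illustration)
J, P = 3, 10
n_j = np.array([5, 8, 6])          # observations per group
k1, k2, tau = 2.0, 0.5, 1.5
N = n_j.sum()

# Group model matrix M: one indicator column per group
groups = np.repeat(np.arange(J), n_j)
M = np.eye(J)[groups]
Psi = (k1 / P) * M @ M.T + np.eye(N)

R = rng.normal(size=N)             # stand-in partial residuals
one = np.ones(N)
Psi_inv = np.linalg.inv(Psi)

# mu | ... : precision 1' Psi^{-1} 1 + P/k2, mean 1' Psi^{-1} R / precision
prec_mu = one @ Psi_inv @ one + P / k2
mean_mu = (one @ Psi_inv @ R) / prec_mu
mu = rng.normal(mean_mu, np.sqrt(1.0 / (tau * prec_mu)))

# mu_j | ... : precision n_j + P/k1, mean (P mu/k1 + n_j Rbar_j) / precision
Rbar = np.array([R[groups == j].mean() for j in range(J)])
prec_g = n_j + P / k1
mean_g = (P * mu / k1 + n_j * Rbar) / prec_g
mu_j = rng.normal(mean_g, np.sqrt(1.0 / (tau * prec_g)))
```

In practice one would use a Cholesky solve (or the Woodbury identity, given the low-rank structure of \(\Psi\)) rather than an explicit inverse; `np.linalg.inv` is used here only to match the formulas line by line.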

The update for \(\tau\) will be a little different. Let \(\hat f_{ij}\) be the overall prediction for observation \(y_{ij}\) at the current iteration, i.e. the sum over trees of the group-level terminal-node parameters for the corresponding observation. Then:

\[\begin{equation} \begin{aligned} \pi(\tau | \dots) &\propto \Big[ \prod_{i = 1}^{N} \pi(y_i | \tau) \Big] \times \Big[ \prod_{j, l, p} \pi(\mu_{j, l, p} | \tau) \Big] \times \Big[ \prod_{l, p} \pi(\mu_{l, p} | \tau) \Big] \times \pi(\tau) \\ &\propto \tau^{N/2} \exp\Big\{-\tau \frac{\sum_{i= 1}^{N}(y_i - \hat f_i)^2}{2} \Big\} \\ &\quad \times \Big(\frac{\tau P}{k_1}\Big)^{(J N_p)/2} \exp\Big\{-\Big(\frac{\tau P}{k_1}\Big) \frac{\sum_{j, l, p}(\mu_{j, l, p} - \mu_{l, p})^2}{2} \Big\} \\ &\quad \times \Big(\frac{\tau P}{k_2}\Big)^{N_p/2} \exp\Big\{-\Big(\frac{\tau P}{k_2}\Big) \frac{\sum_{l, p}\mu_{l, p}^2}{2} \Big\} \times \tau^{\alpha - 1} \exp\{-\tau \beta \} \end{aligned} \end{equation}\]

\[\begin{equation} \begin{aligned} \tau | \dots \sim Ga\Big( &\frac{N + J N_p + N_p}{2} + \alpha, \\ &\frac{\sum_{i= 1}^{N}(y_i - \hat f_i)^2}{2} + \frac{P \sum_{j, l, p}(\mu_{j, l, p} - \mu_{l, p})^2}{2 k_1} + \frac{P \sum_{l, p}\mu_{l, p}^2}{2 k_2} + \beta \Big) \end{aligned} \end{equation}\]
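The resulting Gibbs step is a single Gamma draw. In the sketch below the three sums of squares are replaced by toy numbers, and all dimensions and hyperparameter values are assumptions:

```python
import numpy as np

rng = np.random.default_rng(3)

# Toy dimensions and hyperparameters (assumptions for illustration)
N, J, P = 100, 4, 10
Np = 12                  # total number of terminal nodes across trees
k1, k2 = 2.0, 0.5
alpha, beta = 1.0, 1.0   # Ga(alpha, beta) prior on tau

# Stand-ins for the three sums of squares in the rate
resid_ss = 50.0          # sum_i (y_i - f_hat_i)^2
group_ss = 3.0           # sum_{j,l,p} (mu_{j,l,p} - mu_{l,p})^2
node_ss = 1.0            # sum_{l,p} mu_{l,p}^2

shape = (N + J * Np + Np) / 2.0 + alpha
rate = (resid_ss / 2.0
        + P * group_ss / (2.0 * k1)
        + P * node_ss / (2.0 * k2)
        + beta)

# numpy's Gamma is parameterised by shape and *scale*, so pass 1/rate
tau = rng.gamma(shape, 1.0 / rate)
```

Note the shape/rate versus shape/scale distinction: the full conditional above is in shape/rate form, while `Generator.gamma` expects a scale, hence the `1.0 / rate`.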