# Source code for conditional_independence.ci_tests.nonparametric.hsic

from typing import Union, List, Dict
import numpy as np
from ...utils.kernels import rbf_kernel, center_fast_mutate
from scipy.stats import gamma
from ...utils import residuals
from ...utils import to_list
from scipy.special import gdtr
import ipdb
import numexpr as ne


def hsic_test_vector(
        x: np.ndarray,
        y: np.ndarray,
        sig: float=1/np.sqrt(2),
        alpha: float=0.05
) -> Dict:
    """
    Test for independence of X and Y using the Hilbert-Schmidt Independence Criterion (HSIC).

    The p-value is obtained from a gamma distribution whose shape and scale are
    matched to approximations of the mean and variance of the statistic.

    Parameters
    ----------
    x:
        samples from X. A 1-d array is reshaped to a single column.
    y:
        samples from Y, with the same number of rows as `x`. A 1-d array is
        reshaped to a single column.
    sig:
        width parameter. ``1 / sig**2`` is passed as the second argument of
        ``rbf_kernel``.
    alpha:
        significance level.

    Returns
    -------
    dict
        ``statistic``: sum of the elementwise product of the two centered kernel
        matrices, divided by ``n**2``.
        ``p_value``: upper-tail probability of ``statistic`` under the gamma
        approximation.
        ``reject``: whether ``p_value < alpha``.
        ``mean_approx``, ``var_approx``: the mean and variance approximations
        used to fit the gamma distribution.

    Raises
    ------
    ValueError
        If `x` and `y` do not have the same number of rows.
    """
    # Imported locally because only `gdtr` is imported at module level.
    from scipy.special import gdtrc

    if x.ndim == 1:
        x = x.reshape((len(x), 1))
    if y.ndim == 1:
        y = y.reshape((len(y), 1))
    n = x.shape[0]
    if y.shape[0] != n:
        raise ValueError("Y should have the same number of samples as X")

    kernel_precision = 1/(sig**2)

    # === COMPUTE CENTRALIZED KERNEL MATRICES
    kx = rbf_kernel(x, kernel_precision)
    ky = rbf_kernel(y, kernel_precision)
    # The off-diagonal sums are needed from the *uncentered* kernels, so they are
    # taken before centering (center_fast_mutate presumably overwrites its
    # argument in place -- TODO confirm against utils.kernels).
    kx_off_diag_sum = kx.sum() - kx.trace()
    ky_off_diag_sum = ky.sum() - ky.trace()
    kx_centered = center_fast_mutate(kx)
    ky_centered = center_fast_mutate(ky)

    # === COMPUTE STATISTIC
    # Sum of the elementwise product; for symmetric matrices this is the same as
    # trace(kx_centered @ ky_centered) without forming the matrix product.
    statistic = 1/n**2 * ne.evaluate('sum(a * b)', {'a': kx_centered, 'b': ky_centered})

    # Theorem 3: approximate mean, from the average off-diagonal kernel entries.
    mu_x = 1/(n*(n-1)) * kx_off_diag_sum
    mu_y = 1/(n*(n-1)) * ky_off_diag_sum
    mean_approx = 1/n * (1 + mu_x*mu_y - mu_x - mu_y)
    # Theorem 4: approximate variance, from the off-diagonal entries of B.
    # NOTE(review): var_coef divides by n*(n-1)*(n-2)*(n-3) and is zero for
    # n in {4, 5}, so the approximation is only meaningful for n >= 6.
    var_coef = 2*(n-4)*(n-5)/(n*(n-1)*(n-2)*(n-3))
    B = (kx_centered * ky_centered)**2
    var_approx = var_coef * (B.sum() - np.trace(B)) / n**2

    # Moment-matched gamma parameters. Named apart from the `alpha` argument so
    # the significance level is not overwritten.
    gamma_shape = mean_approx ** 2 / var_approx
    gamma_scale = var_approx / mean_approx

    # Use the complemented gamma CDF directly: `1 - gdtr(...)` rounds to exactly
    # 0.0 once the CDF is within machine epsilon of 1, losing small p-values.
    p_value = gdtrc(1/gamma_scale, gamma_shape, statistic)

    return dict(
        statistic=statistic,
        p_value=p_value,
        reject=p_value < alpha,
        mean_approx=mean_approx,
        var_approx=var_approx
    )


def hsic_test(
        suffstat: np.ndarray,
        i: int,
        j: int,
        cond_set: Union[List[int], int, None]=None,
        alpha: float=0.05
) -> Dict:
    """
    Test for (conditional) independence using the Hilbert-Schmidt Independence Criterion.

    If a conditioning set is specified, first perform non-parametric regression, then test residuals.

    Parameters
    ----------
    suffstat:
        Matrix of samples.
    i:
        column position of first variable.
    j:
        column position of second variable.
    cond_set:
        column positions of conditioning set. It is normalised with ``to_list``
        before use.
    alpha:
        Significance level of the test.

    Returns
    -------
    dict
        The result of ``hsic_test_vector``, applied to columns `i` and `j` when
        the conditioning set is empty, and otherwise to the two arrays returned
        by ``residuals(suffstat, i, j, cond_set)``.
    """
    cond_set = to_list(cond_set)

    # Marginal test: no conditioning variables, so test the raw columns.
    if len(cond_set) == 0:
        return hsic_test_vector(suffstat[:, i], suffstat[:, j], alpha=alpha)

    # Conditional test: test the residuals of i and j given the conditioning set.
    residuals_i, residuals_j = residuals(suffstat, i, j, cond_set)
    return hsic_test_vector(residuals_i, residuals_j, alpha=alpha)