使用 Funsor 的方法

本模块提供了许多使用 Funsor 的高级算法。

forward_filter_backward_rsample(factors: Dict[str, Funsor], eliminate: FrozenSet[str], plates: FrozenSet[str], sample_inputs: Dict[str, type] = {}, rng_key=None)[source]

一种前向-滤波后向批量重新参数化采样算法,用于变分推断。主要的用例是对结构化变分后验执行高斯张量变量消除。

参数
  • factors (dict) – 将采样站点名称映射到在该采样站点创建的 Funsor factor 的字典。

  • frozenset – 要边缘化的潜变量名称集合以及要聚合的 plate 名称集合。

  • plates – 要聚合的 plate 名称集合。

  • sample_inputs (dict) – 可选的外部采样索引字典,将根据这些索引批量绘制样本。

  • rng_key – JAX 后端的随机数 key。

返回值

一个样本对 samples:Dict[str, Tensor], log_prob: Tensor,包含样本及其在每个样本处的对数密度。如果 sample_inputs 非空,则两个输出都将是批量的。

返回类型

tuple

forward_filter_backward_precondition(factors: Dict[str, Funsor], eliminate: FrozenSet[str], plates: FrozenSet[str], aux_name: str = 'aux')[source]

一种前向-滤波后向预处理算法,用于变分推断或哈密顿蒙特卡洛中的预处理。主要的用例是对结构化变分后验执行高斯张量变量消除,并可选地使用学习到的后验来确定 HMC 中的动量。

参数
  • factors (dict) – 将采样站点名称映射到在该采样站点创建的 Funsor factor 的字典。

  • frozenset – 要边缘化的潜变量名称集合以及要聚合的 plate 名称集合。

  • plates – 要聚合的 plate 名称集合。

  • aux_name (str) – 包含白噪声的辅助变量的名称。

返回值

一个样本对 samples:Dict[str, Tensor], log_prob: Tensor,包含样本及其在每个样本处的对数密度。这两个输出都取决于一个由 aux_name 命名的向量,例如 aux: Reals[d],其中 d 是被消除变量中的元素总数。

返回类型

tuple