操作

操作类

class BinaryOp(*args, **kwargs)[source]

基类: Op

arity = 2
class FinitaryOp(*args, **kwargs)[source]

基类: Op

arity = 1
class LogAbsDetJacobianOp(*args, **kwargs)

基类: BinaryOp

static default(x, y, fn)
dispatcher = <dispatched log_abs_det_jacobian>
name = 'log_abs_det_jacobian'
signature = <Signature (x, y, fn)>
class NullaryOp(*args, **kwargs)[source]

基类: Op

arity = 0
class Op(*args, **kwargs)[source]

基类: object

所有基项数学操作的抽象基类。

操作(Ops)接受 arity 个左侧位置参数,这些参数可以是 funsor,后面可以跟额外的非 funsor 参数和关键字参数。这些额外的参数和关键字参数必须有默认值。

封装新的后端操作时,请记住这些限制,这可能要求您在将后端函数转换为操作之前对其进行封装

  • 只能通过使用 @UnaryOp.make, @BinaryOp.make 等装饰器为默认实现创建新的操作。

  • 对于 arity 为 1、2 等的情况,可以通过 @my_op.register(type1), @my_op.register(type1, type2) 等注册后端特定的实现。模式只能包含前 arity 个类型。

  • 只有前 arity 个参数可以是 funsors。剩余的参数和关键字参数必须都是基础 Python 数据。

变量

~.arity (int) – 此操作接受的 funsor 参数数量。必须由子类定义。

参数
  • *args

  • **kwargs – 此操作的所有额外参数,不包括 .arity 之前的参数,

arity = NotImplemented
register(*pattern)
classmethod subclass_register(*pattern)[source]
classmethod make(fn=None, *, name=None, metaclass=None, module_name='funsor.ops')[source]

用于创建新的 Op 子类及其新的默认实例的工厂函数。

参数

fn (callable) – 一个可以检查其签名的函数。

返回值

新的默认实例。

返回类型

Op

class TernaryOp(*args, **kwargs)[source]

基类: Op

arity = 3
class TransformOp(*args, **kwargs)[source]

基类: UnaryOp

set_inv(fn)[source]
参数

fn (callable) – 一个输入参数 y 并输出值 x 的函数,使得 y=self(x)

set_log_abs_det_jacobian(fn)[source]
参数

fn (callable) – 一个输入两个参数 x, y 的函数,其中 y=self(x),并返回 log(abs(det(dy/dx)))

static inv(x)[source]
static log_abs_det_jacobian(x, y)[source]
class UnaryOp(*args, **kwargs)[source]

基类: Op

arity = 1
class WrappedTransformOp(*args, **kwargs)

基类: TransformOp

后端 Transform 对象的包装器,提供 .inv.log_abs_det_jacobian 方法。它还会在第一次调用 __call__() 时验证形状。

static default(x, fn, *, validate_args=True)

后端 Transform 对象的包装器,提供 .inv.log_abs_det_jacobian 方法。它还会在第一次调用 __call__() 时验证形状。

dispatcher = <dispatched wrapped_transform>
property inv
property log_abs_det_jacobian
name = 'wrapped_transform'
signature = <Signature (x, fn, *, validate_args=True)>
declare_op_types(locals_, all_, name_)[source]

内置操作

abs = ops.abs

返回参数的绝对值。

add = ops.add

等同于 a + b。

and_ = ops.and_

等同于 a & b。

atanh = ops.atanh

返回 x 的反双曲正切。

eq = ops.eq

等同于 a == b。

exp = ops.exp

返回 e 的 x 次幂。

floordiv = ops.floordiv

等同于 a // b。

ge = ops.ge

等同于 a >= b。

getitem = ops.getitem
getslice = ops.getslice
gt = ops.gt

等同于 a > b。

invert = ops.invert

等同于 ~a。

le = ops.le

等同于 a <= b。

lgamma = ops.lgamma

Gamma 函数在 x 处的绝对值的自然对数。

log = ops.log
log1p = ops.log1p

返回 1+x 的自然对数(底数为 e)。

计算结果在 x 接近零时精确。

lshift = ops.lshift

等同于 a << b。

lt = ops.lt

等同于 a < b。

matmul = ops.matmul

等同于 a @ b。

max = ops.max
min = ops.min
mod = ops.mod

等同于 a % b。

mul = ops.mul

等同于 a * b。

ne = ops.ne

等同于 a != b。

neg = ops.neg

等同于 -a。

null = ops.null

一个占位符结合操作,可与任何其他操作统一

or_ = ops.or_

等同于 a | b。

pos = ops.pos

等同于 +a。

pow = ops.pow

等同于 a ** b。

reciprocal = ops.reciprocal
rshift = ops.rshift

等同于 a >> b。

safediv = ops.safediv
safesub = ops.safesub
sigmoid = ops.sigmoid
sqrt = ops.sqrt

返回 x 的平方根。

sub = ops.sub

等同于 a - b。

tanh = ops.tanh

返回 x 的双曲正切。

truediv = ops.truediv

等同于 a / b。

xor = ops.xor

等同于 a ^ b。

数组操作

all = ops.all
amax = ops.amax
amin = ops.amin
any = ops.any
argmax = ops.argmax
argmin = ops.argmin
astype = ops.astype
cat = ops.cat
cholesky = ops.cholesky

类似于 numpy.linalg.cholesky(),但对标量矩阵使用 sqrt。

cholesky_inverse = ops.cholesky_inverse

类似于 torch.cholesky_inverse(),但支持批处理和梯度。

cholesky_solve = ops.cholesky_solve
clamp = ops.clamp
detach = ops.detach
diagonal = ops.diagonal
einsum = ops.einsum
expand = ops.expand
finfo = ops.finfo
flip = ops.flip
full_like = ops.full_like
isnan = ops.isnan
logaddexp = ops.logaddexp
logsumexp = ops.logsumexp
mean = ops.mean
new_arange = ops.new_arange
new_eye = ops.new_eye
new_full = ops.new_full
new_zeros = ops.new_zeros
permute = ops.permute
prod = ops.prod
qr = ops.qr
randn = ops.randn
sample = ops.sample
scatter = ops.scatter
scatter_add = ops.scatter_add
stack = ops.stack
std = ops.std
sum = ops.sum
transpose = ops.transpose
triangular_inv = ops.triangular_inv
triangular_solve = ops.triangular_solve
unsqueeze = ops.unsqueeze
var = ops.var