# Source code for robokit.lie.so3

import abc
from typing import TYPE_CHECKING, Dict, Literal, Optional, Set, Tuple, Union, overload

import numpy as np
import pinocchio as pin

from robokit import CONFIG
from robokit.types import ArrayLike


try:
    import torch  # torch is optional

    _TORCH_AVAILABLE = True
except ImportError:
    _TORCH_AVAILABLE = False


try:
    import warp as wp  # warp is optional

    _WARP_AVAILABLE = True
except ImportError:
    _WARP_AVAILABLE = False

if TYPE_CHECKING:
    from robokit.lie.pinocchio_so3 import PinocchioSO3
    from robokit.lie.torch_so3 import TorchSO3
    from robokit.lie.warp_so3 import WarpSO3


class SO3(abc.ABC):
    """SO(3) representation using quaternions.

    Internal parameterization: [qw, qx, qy, qz].
    Tangent parameterization: [omega_x, omega_y, omega_z].

    ``SO3`` itself acts as a dispatching front end: ``SO3(...)``,
    :meth:`from_matrix` and :meth:`exp` infer the array type of their input
    and hand off to the matching backend subclass
    (``"numpy"`` -> ``PinocchioSO3``, ``"torch"`` -> ``TorchSO3``,
    ``"warp"`` -> ``WarpSO3``).
    """

    @staticmethod
    def _infer_array_type(
        param: Union[pin.Quaternion, ArrayLike],
        array_type: Optional[Literal["numpy", "torch", "warp"]] = None,
    ) -> Literal["numpy", "torch", "warp"]:
        """Return ``array_type`` if given, otherwise infer it from the type of ``param``.

        Raises:
            ValueError: if ``array_type`` is None and ``param`` is none of
                ``pin.Quaternion``/``np.ndarray`` (-> "numpy"),
                ``torch.Tensor`` (-> "torch", only when torch is importable) or
                ``wp.array`` (-> "warp", only when warp is importable).
        """
        if array_type is not None:
            # An explicit array_type is trusted here; _validate_backends rejects
            # unsupported values.
            return array_type
        if isinstance(param, (np.ndarray, pin.Quaternion)):
            return "numpy"
        if _TORCH_AVAILABLE and isinstance(param, torch.Tensor):
            return "torch"
        if _WARP_AVAILABLE and isinstance(param, wp.array):
            return "warp"
        raise ValueError(
            f"Cannot infer array_type from type: {type(param)}. Expected pin.Quaternion, numpy.ndarray, torch.Tensor, or wp.array."
        )

    @staticmethod
    def _get_so3_class(
        array_type: Literal["numpy", "torch", "warp"],
    ) -> Union["type[PinocchioSO3]", "type[TorchSO3]", "type[WarpSO3]"]:
        """Return the backend subclass that implements ``array_type``.

        The subclass module is imported inside the matching branch, so it is
        only loaded when that array type is actually requested (presumably also
        to avoid a circular import with the subclass modules -- TODO confirm).

        Raises:
            ValueError: if ``array_type`` is not "numpy", "torch" or "warp".
        """
        if array_type == "numpy":
            from robokit.lie.pinocchio_so3 import PinocchioSO3

            return PinocchioSO3
        if array_type == "torch":
            from robokit.lie.torch_so3 import TorchSO3

            return TorchSO3
        if array_type == "warp":
            from robokit.lie.warp_so3 import WarpSO3

            return WarpSO3
        raise ValueError(f"Unsupported array_type: {array_type}")

    @staticmethod
    def _validate_backends(
        array_type: Literal["numpy", "torch", "warp"],
        compute_backend: Optional[Literal["pinocchio", "torch", "warp"]],
    ) -> Literal["pinocchio", "torch", "warp"]:
        """Resolve the compute backend for ``array_type`` and check it is allowed.

        When ``compute_backend`` is None the per-array-type default is returned:
        "numpy" -> "pinocchio", "torch" -> ``CONFIG.default_torch_compute_backend``,
        "warp" -> "warp".

        Raises:
            ValueError: if ``array_type`` is unsupported, or if
                ``compute_backend`` is not permitted for ``array_type``
                (numpy: pinocchio; torch: torch or warp; warp: warp).
        """
        # Built on every call so CONFIG.default_torch_compute_backend is read at
        # call time rather than captured once at import.
        backend_config: Dict[str, Tuple[Literal["pinocchio", "torch", "warp"], Set[str]]] = {
            "numpy": ("pinocchio", {"pinocchio"}),
            "torch": (CONFIG.default_torch_compute_backend, {"torch", "warp"}),
            "warp": ("warp", {"warp"}),
        }
        if array_type not in backend_config:
            # Raise ValueError (not a bare KeyError) so an explicit but unsupported
            # array_type fails the same way as in _get_so3_class.
            raise ValueError(f"Unsupported array_type: {array_type}")
        default, valid = backend_config[array_type]
        if compute_backend is None:
            return default
        if compute_backend not in valid:
            # sorted() keeps the message deterministic; set iteration order
            # depends on string hash randomization.
            raise ValueError(
                f"array_type='{array_type}' cannot use compute_backend='{compute_backend}'. "
                f"Valid options: {sorted(valid)}"
            )
        return compute_backend

    @overload
    def __new__(
        cls, so3_like: "wp.array", array_type: Literal["warp"], compute_backend: Optional[Literal["warp"]] = ...
    ) -> "WarpSO3": ...

    @overload
    def __new__(
        cls,
        so3_like: "torch.Tensor",
        array_type: Optional[Literal["torch"]] = ...,
        compute_backend: Optional[Literal["torch", "warp"]] = ...,
    ) -> "TorchSO3": ...

    @overload
    def __new__(
        cls,
        so3_like: Union[pin.Quaternion, np.ndarray],
        array_type: Optional[Literal["numpy"]] = ...,
        compute_backend: Optional[Literal["pinocchio"]] = ...,
    ) -> "PinocchioSO3": ...

    def __new__(
        cls,
        so3_like: Union[pin.Quaternion, ArrayLike],
        array_type: Optional[Literal["numpy", "torch", "warp"]] = None,
        compute_backend: Optional[Literal["pinocchio", "torch", "warp"]] = None,
    ) -> "SO3":
        """Construct an SO3, dispatching to a backend subclass when called as ``SO3(...)``.

        Args:
            so3_like: An existing ``SO3`` (returned unchanged), a
                ``pin.Quaternion``, or an array-like parameterization.
            array_type: Array family of ``so3_like``; inferred when None.
            compute_backend: Backend to compute with; defaulted per array type
                when None.

        Raises:
            ValueError: if the array type cannot be inferred or is unsupported,
                or the compute backend is not allowed for it.

        NOTE(review): Python's ``type.__call__`` invokes ``__init__`` on whatever
        ``__new__`` returns whenever it is an instance of ``cls``, using the
        *original* call arguments. This includes the pass-through case, where
        ``so3_like`` is the object itself, and dispatch, where ``array_type`` and
        ``compute_backend`` may still be None. Subclass ``__init__`` methods
        must tolerate both -- confirm.
        """
        # Pass-through: an existing SO3 is returned as-is; array_type and
        # compute_backend are ignored in this case.
        if isinstance(so3_like, SO3):
            return so3_like

        if cls is not SO3:
            # Called on a concrete subclass: plain allocation; initialisation is
            # left to the subclass's __init__.
            return super().__new__(cls)

        # Called directly on SO3: pick the backend subclass and allocate it.
        array_type = cls._infer_array_type(so3_like, array_type)
        compute_backend = cls._validate_backends(array_type, compute_backend)
        so3_cls = cls._get_so3_class(array_type)
        return so3_cls.__new__(so3_cls, so3_like, array_type, compute_backend)  # type: ignore

    @property
    @abc.abstractmethod
    def wxyz(self) -> Union[np.ndarray, "torch.Tensor"]:
        """Quaternion part of SO(3) as [qw, qx, qy, qz]"""
        ...

    @staticmethod
    def from_matrix(
        matrix: Union[np.ndarray, "torch.Tensor"],
        array_type: Optional[Literal["numpy", "torch", "warp"]] = None,
        compute_backend: Optional[Literal["pinocchio", "torch", "warp"]] = None,
    ) -> "SO3":
        """Build an SO3 from a matrix by delegating to the backend subclass's ``from_matrix``.

        Raises:
            ValueError: if the array type cannot be inferred or is unsupported,
                or the compute backend is not allowed for it.
        """
        array_type = SO3._infer_array_type(matrix, array_type)
        compute_backend = SO3._validate_backends(array_type, compute_backend)
        so3_cls = SO3._get_so3_class(array_type)
        # so3_cls must override from_matrix; otherwise this lookup resolves back
        # to SO3.from_matrix and recurses without bound.
        return so3_cls.from_matrix(matrix, array_type, compute_backend)  # pyright: ignore[reportArgumentType]

    def as_matrix(self):
        """Counterpart of :meth:`from_matrix`; implemented by backend subclasses."""
        raise NotImplementedError

    @staticmethod
    def exp(
        log_rot: Union[np.ndarray, "torch.Tensor"],
        array_type: Optional[Literal["numpy", "torch", "warp"]] = None,
        compute_backend: Optional[Literal["pinocchio", "torch", "warp"]] = None,
    ) -> "SO3":
        """Build an SO3 from a tangent vector by delegating to the backend subclass's ``exp``.

        Raises:
            ValueError: if the array type cannot be inferred or is unsupported,
                or the compute backend is not allowed for it.
        """
        array_type = SO3._infer_array_type(log_rot, array_type)
        compute_backend = SO3._validate_backends(array_type, compute_backend)
        so3_cls = SO3._get_so3_class(array_type)
        # so3_cls must override exp; otherwise this lookup resolves back to
        # SO3.exp and recurses without bound.
        return so3_cls.exp(log_rot, array_type, compute_backend)  # pyright: ignore[reportArgumentType]

    def log(self) -> Union[np.ndarray, "torch.Tensor"]:
        """Return tangent vector [omega_x, omega_y, omega_z]."""
        raise NotImplementedError

    def __repr__(self):
        # Implemented by backend subclasses.
        raise NotImplementedError

    def __str__(self):
        return self.__repr__()

    def __mul__(self, other: "SO3") -> "SO3":
        # Implemented by backend subclasses.
        raise NotImplementedError

    def __rmul__(self, other: "SO3") -> "SO3":
        # Implemented by backend subclasses.
        raise NotImplementedError

    def __matmul__(self, other: "SO3") -> "SO3":
        # ``a @ b`` is an alias for ``a * b``.
        return self.__mul__(other)

    def __rmatmul__(self, other: "SO3") -> "SO3":
        return self.__rmul__(other)