Source code for sionna.phy.block

## SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.# SPDX-License-Identifier: Apache-2.0#"""Definition of Sionna PHY Object and Block classes"""fromabcimportABCfromabcimportabstractmethodimporttensorflowastfimportnumpyasnpfrom.configimportconfig,dtypes
class Object(ABC):
    """Abstract class for Sionna PHY objects

    Parameters
    ----------
    precision : `None` (default) | "single" | "double"
        Precision used for internal calculations.
        If set to `None`, the default
        :attr:`~sionna.phy.config.Config.precision` is used.

    Raises
    ------
    ValueError
        If ``precision`` is neither `None`, "single" nor "double".
    """

    # pylint: disable=unused-argument
    def __init__(self, *args, precision=None, **kwargs):
        # Extra positional/keyword arguments are accepted (and ignored) so
        # that subclasses can forward their own arguments cooperatively.
        if precision is None:
            self._precision = config.precision
        elif precision in ['single', 'double']:
            self._precision = precision
        else:
            raise ValueError("'precision' must be 'single' or 'double'")

    @property
    def precision(self):
        """
        `str`, "single" | "double" : Precision used for all computations
        """
        return self._precision

    @property
    def cdtype(self):
        """
        `tf.complex` : Type for complex floating point numbers
        """
        return dtypes[self.precision]['tf']['cdtype']

    @property
    def rdtype(self):
        """
        `tf.float` : Type for real floating point numbers
        """
        return dtypes[self.precision]['tf']['rdtype']

    def _cast_or_check_precision(self, v):
        """Cast a tensor to the internal precision, or check a variable's.

        Variables are never cast, because casting would create a new tensor
        and detach it from the variable. Instead, their dtype is validated.

        Parameters
        ----------
        v : `tf.Variable` | tensor-like
            Value to cast or check.

        Returns
        -------
        `tf.Variable` | `tf.Tensor`
            A variable is returned unchanged. Any other input is converted
            to a tensor and cast to :attr:`cdtype` if complex, else to
            :attr:`rdtype`.

        Raises
        ------
        ValueError
            If ``v`` is a complex or floating `tf.Variable` whose dtype does
            not match :attr:`cdtype` or :attr:`rdtype`, respectively.
        """
        if isinstance(v, tf.Variable):
            if v.dtype.is_complex:
                expected = self.cdtype
            elif v.dtype.is_floating:
                expected = self.rdtype
            else:
                # Variables of other dtypes (e.g. integer, bool) are
                # returned unchecked.
                return v
            if v.dtype != expected:
                # Report the dtype that was actually expected for this
                # kind of variable (real vs. complex).
                raise ValueError(
                    f"Wrong dtype. Expected {expected}, got {v.dtype}")
            return v

        # Cast tensors (and tensor-likes) to the correct dtype
        if not isinstance(v, tf.Tensor):
            v = tf.convert_to_tensor(v)
        if v.dtype.is_complex:
            return tf.cast(v, self.cdtype)
        # Non-complex tensors, including integer ones, become real floats
        return tf.cast(v, self.rdtype)
class Block(Object):
    """Abstract base class for Sionna PHY processing blocks

    Calling a block converts floating-point and complex inputs to the
    block's precision. On the first call it then invokes :meth:`build`
    with the shapes of the inputs. Finally it dispatches to :meth:`call`.

    Parameters
    ----------
    precision : `None` (default) | "single" | "double"
        Precision used for internal calculations and outputs.
        If set to `None`, the default
        :attr:`~sionna.phy.config.Config.precision` is used.
    """

    # pylint: disable=unused-argument
    def __init__(self, *args, precision=None, **kwargs):
        super().__init__(precision=precision, **kwargs)
        # Set to True once build() has run. Checked on every call so that
        # the block is not rebuilt each time it is invoked in Eager mode.
        self._built = False

    @property
    def built(self):
        """
        `bool` : Indicates if the block's build function was called
        """
        return self._built

    def build(self, *arg_shapes, **kwarg_shapes):
        """
        Method to (optionally) initialize the block based on the inputs'
        shapes. Receives one `tf.TensorShape` per positional and keyword
        input. The default implementation does nothing.
        """

    @abstractmethod
    def call(self, *args, **kwargs):
        """
        Abstract call method with arbitrary arguments and keyword arguments
        """
        raise NotImplementedError("Subclasses must implement this method.")

    def _convert_to_tensor(self, v):
        """Cast floating or complex tensors to the block's precision.

        NumPy arrays are first converted to tensors. Tensors of any other
        dtype, and non-tensor values, are returned unchanged.
        """
        if isinstance(v, np.ndarray):
            v = tf.convert_to_tensor(v)
        if not isinstance(v, tf.Tensor):
            return v
        if v.dtype.is_floating:
            return tf.cast(v, self.rdtype)
        if v.dtype.is_complex:
            return tf.cast(v, self.cdtype)
        return v

    def _get_shape(self, v):
        """Return the `tf.TensorShape` of an input.

        Inputs that cannot be converted to a tensor are inspected as-is.
        Values without a ``shape`` attribute map to a scalar shape.
        """
        try:
            candidate = tf.convert_to_tensor(v)
        except (TypeError, ValueError):
            candidate = v
        if not hasattr(candidate, "shape"):
            return tf.TensorShape([])
        return tf.TensorShape(candidate.shape)

    def __call__(self, *args, **kwargs):
        converted_args, converted_kwargs = tf.nest.map_structure(
            self._convert_to_tensor, [args, kwargs])
        # tf.init_scope lifts the enclosed ops out of any surrounding
        # function-building graph, so build() runs in the outer context.
        with tf.init_scope():  # pylint: disable=not-context-manager
            if not self._built:
                arg_shapes, kwarg_shapes = tf.nest.map_structure(
                    self._get_shape, [converted_args, converted_kwargs])
                self.build(*arg_shapes, **kwarg_shapes)
                self._built = True
        return self.call(*converted_args, **converted_kwargs)