a_tensors = [
        x for x in nest.flatten(a, expand_composites=True)
        if isinstance(x, tensor.Tensor)