# Reference values as a float32 array of shape (2, 1, 3): two outer entries,
# each holding one row in which the same value is repeated three times.
true_full_output = np.array(
    [
        [[0.751109, 0.751109, 0.751109]],
        [[0.895509, 0.895509, 0.895509]],
    ],
    dtype=np.float32,
)