expected_output = {
        "a":
            np.array(
                [[1, 1], [3, -3], [3, -3]],
                dtype=np.float32).reshape(3, 1, 2, 1),