def test_missing_treatment_indicator_from_inputs_during_training_raises_value_error(
      self,
  ):
    model = self._get_compiled_model()
    inputs = {"x": tf.ones((3, 1))}