with self.cached_session():
      def f(x):  # pylint: disable=invalid-name
        """Returns `nn_impl.swish(x, beta=0.5)`; the function under gradient check."""
        activated = nn_impl.swish(x, beta=0.5)
        return activated

      theoretical, numerical = gradient_checker_v2.compute_gradient(