Reducing Sigmoid computations by (at least) 88.0797077977882%

A classic implementation issue in machine learning is reducing the cost of computing the sigmoid function

\sigma(a) = \frac{1}{1+\exp(-a)}.

Specifically, it is common to profile your code and discover that 90% of the time is spent computing the \exp in that function.  This comes up often in neural networks, as well as in various probabilistic architectures, such as sampling from Ising models or Boltzmann machines.  There are quite a few classic approximations to this function, using simple polynomials and the like, that can be used in neural networks.

Today, however, I was faced with a sampling problem involving the repeated use of the sigmoid function, and I noticed a simple trick that could reduce the number of sigmoids by about 88% without introducing any approximation.  The particular details of the situation aren’t interesting, but I repeatedly needed to do something like the following:

  1. Input a \in \Re
  2. Compute a random number r \in [0,1]
  3. If r < \sigma(a)
  4.     Output +1
  5. Else
  6.     Output -1
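In code, the baseline sampler is just the following (a minimal Python sketch; the function names are mine):

```python
import math
import random

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def sample_naive(a):
    """Output +1 with probability sigmoid(a), else -1 (one exp per call)."""
    r = random.random()
    return +1 if r < sigmoid(a) else -1
```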

Now, let’s assume for simplicity that a is positive.  (Otherwise, sample using -a and then switch the sign of the output.)  There are two observations to make:

  1. If a is large, then you are likely to output +1
  2. Otherwise, there are easy upper and lower bounds on the probability of outputting +1

This leads to the following algorithm:

  1. Input a \in \Re
  2. Compute a random number r \in [0,1]
  3. If a \geq 2
  4.     If r \leq 0.880797077977882 or r \leq \sigma(a)
  5.         Output +1
  6.     Else
  7.         Output -1
  8. Else
  9.     If r > .5 + a/4
  10.         Output -1
  11.     Else if r \leq .5 + a/5.252141128658 or r \leq \sigma(a)
  12.         Output +1
  13.     Else
  14.         Output -1
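A Python sketch of the full algorithm (function names are mine; the constants are the ones from the listing above), including the sign flip for negative a mentioned earlier:

```python
import math
import random

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

SIGMA_2 = 0.880797077977882            # precomputed sigma(2)
LOWER_SLOPE = 1.0 / 5.252141128658     # slope of the lower bound .5 + a/5.252...

def sample_nonneg(a):
    """Output +1 with probability sigmoid(a), else -1; assumes a >= 0."""
    r = random.random()
    if a >= 2.0:
        # sigmoid(a) >= sigma(2), so r <= SIGMA_2 short-circuits the exp
        return +1 if r <= SIGMA_2 or r <= sigmoid(a) else -1
    if r > 0.5 + a / 4.0:               # above the upper bound: certainly -1
        return -1
    if r <= 0.5 + a * LOWER_SLOPE:      # below the lower bound: certainly +1
        return +1
    return +1 if r <= sigmoid(a) else -1  # the rare expensive case

def sample(a):
    """Handle negative a by sampling with -a and flipping the sign."""
    return sample_nonneg(a) if a >= 0 else -sample_nonneg(-a)
```

The distribution is exactly the same as the naive sampler's; only the number of \exp evaluations changes.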

The idea is as follows:

  1. If a \geq 2, then we can lower-bound the probability of outputting +1 by the pre-computed value \sigma(2) \approx 0.8807..., and short-circuit the computation in many cases.
  2. If a < 2, then we can upper-bound the sigmoid function by .5 + a/4 (the tangent line at zero, since \sigma'(0) = 1/4).
  3. If a < 2, then we can also lower-bound it by .5 + a/5.252141....  (This constant was found numerically; it makes the bound tight at a = 2.)
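These bounds are easy to sanity-check numerically (a quick script over a grid of a values in [0,2], using the constants above):

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

for i in range(201):
    a = 2.0 * i / 200.0                                     # grid over [0, 2]
    assert sigmoid(a) <= 0.5 + a / 4.0 + 1e-9               # upper bound
    assert 0.5 + a / 5.252141128658 <= sigmoid(a) + 1e-9    # lower bound

# the short-circuit constant really is sigma(2)
assert abs(sigmoid(2.0) - 0.880797077977882) < 1e-12
```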

The three cases are illustrated in the following figure, where the input a is on the x-axis, and the random number r is on the y-axis.


Since, for any a, at least a fraction \sigma(2)\approx.8807 of the random draws will be short-circuited, sigmoid calls are reduced by at least that fraction.  If a is often near zero, you will do even better.

Obviously, you can take this further by adding more cases, which may or may not be helpful, depending on the architecture and the cost of branching vs. the cost of computing an \exp.

9 thoughts on “Reducing Sigmoid computations by (at least) 88.0797077977882%”

  1. Neat. This technique is known as rejection sampling (a famous example is the Ziggurat algorithm for the Normal distribution).

  2. Wait, actually this isn’t the same thing as rejection sampling, but it’s basically the same idea.

  3. Alternatively, you could use the “Gumbel max trick”, which avoids the exp because you can stay in the natural parameterization. However, the Gumbel max trick requires twice as many random numbers, and a Gumbel variate is typically generated by applying two logs to a uniform variate: $-\log(-\log(u))$. So why consider this at all? Two logs are about as expensive as two exps, and we only need one! The saving grace is that the Gumbel variates can be precomputed because they don’t depend on the data. If you amortize the cost of precomputing the Gumbel variates, the Gumbel max trick is wicked fast.

  4. Hi Tim,

    I love it! If I understand it correctly, you just do the following:

    1. Input a
    2. Sample x \sim \text{Gumbel}
    3. Sample y \sim \text{Gumbel}
    4. If x + a > y output +1
    5. Else, output -1
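    In Python, that recipe would look something like this (a sketch; `gumbel` is a helper I'm naming here, using the inverse-CDF formula from the comment above):

    ```python
    import math
    import random

    def gumbel():
        # standard Gumbel variate via inverse CDF: -log(-log(U))
        return -math.log(-math.log(random.random()))

    def sample_gumbel(a):
        # argmax of the noisy logits {a + Gumbel, 0 + Gumbel} outputs
        # +1 with probability sigmoid(a)
        return +1 if a + gumbel() > gumbel() else -1
    ```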

  5. And, yes, if you want identical samples for the same r, you should negate it. (Though you still get samples from the same distribution if you don’t negate, so it probably isn’t worth the effort in most cases.)

  6. Even simpler method, which is similar to the Gumbel trick.

    Let u = uniform(0,1)
    return +1 if a > logit(u) else -1

    It’s super obvious in hindsight!
    * logit is the inverse of sigmoid.
    * logit is strictly monotonically increasing, so you can apply it to both sides of the comparison:

      sigmoid(a) > u  ⟺  logit(sigmoid(a)) > logit(u)  ⟺  a > logit(u).

    Bonus: It also turns out that the difference of two Gumbel RVs is logistic, which is unsurprising given my previous comment. So, generating two Gumbel RVs is unnecessary!
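    As a sketch (my function names, one log per sample):

    ```python
    import math
    import random

    def logit(u):
        # inverse of the sigmoid: logit(sigmoid(a)) == a
        return math.log(u / (1.0 - u))

    def sample_logit(a):
        # a > logit(u)  <=>  sigmoid(a) > u, so this outputs +1
        # with probability sigmoid(a)
        return +1 if a > logit(random.random()) else -1
    ```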
