A classic implementation issue in machine learning is reducing the cost of computing the sigmoid function

$$\sigma(a) = \frac{1}{1 + \exp(-a)}.$$

Specifically, it is common to profile your code and discover that 90% of the time is spent computing the $\exp$ in that function. This comes up often in neural networks, as well as in various probabilistic architectures, such as sampling from Ising models or Boltzmann machines. There are quite a few classic approximations to the function, using simple polynomials and the like, that can be used in neural networks.

Today, however, I was faced with a sampling problem involving the repeated use of the sigmoid function, and I noticed a simple trick that could reduce the number of sigmoid evaluations by about 88% without introducing any approximation. The particular details of the situation aren’t interesting, but I repeatedly needed to do something like the following:

- Input $a$
- Compute a random number $r \in [0, 1]$
- If $r \le \sigma(a)$
  - Output $+1$
- Else
  - Output $-1$
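For reference, here is a minimal Python sketch of this baseline sampler. The names `sigmoid` and `sample_basic` are hypothetical, and the random number $r$ is passed in explicitly (an assumption of this sketch) so that the same $r$ can be replayed through other variants and compared.

```python
import math
import random

def sigmoid(a: float) -> float:
    """Logistic function 1 / (1 + exp(-a)), written to avoid overflow for large |a|."""
    if a >= 0:
        return 1.0 / (1.0 + math.exp(-a))
    e = math.exp(a)  # a < 0, so exp(a) lies in (0, 1) and cannot overflow
    return e / (1.0 + e)

def sample_basic(a: float, r: float) -> int:
    """Return +1 if r <= sigmoid(a), else -1; with r ~ Uniform[0, 1), P(+1) = sigmoid(a)."""
    return 1 if r <= sigmoid(a) else -1

if __name__ == "__main__":
    rng = random.Random(0)
    n = 100_000
    hits = sum(sample_basic(1.0, rng.random()) == 1 for _ in range(n))
    print(hits / n, sigmoid(1.0))  # empirical frequency vs. exact probability
```

Every call pays for one `math.exp`, which is the cost the rest of the post tries to avoid.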

Now, let’s assume for simplicity that $a$ is positive. (Otherwise, sample using $-a$ and then switch the sign of the output.) There are two observations to make:

- If $a$ is large, then you are likely to output $+1$.
- Otherwise, there are easy upper and lower bounds on the probability of outputting $+1$.

This leads to the following algorithm:

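- Input $a$
- Compute a random number $r \in [0, 1]$
- If $a \ge 2$
  - If $r \le 0.880797077977882$ or $r \le \sigma(a)$
    - Output $+1$
  - Else
    - Output $-1$
- Else
  - If $r > 0.5 + a/4$
    - Output $-1$
  - Else if $r \le 0.5 + a/5.252141128658$ or $r \le \sigma(a)$
    - Output $+1$
  - Else
    - Output $-1$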
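A sketch of the short-circuited sampler in Python, under a few stated assumptions. `FastSampler` and its `sigmoid_calls` counter are hypothetical names added to make the savings measurable. The lower-bound constant `C` is computed as $2/(\sigma(2) - 1/2)$ instead of hard-coding 5.252141128658 (a deliberate choice; the two agree to seven decimal places). Negative inputs use the reflection $-\mathrm{sample}(-a, 1-r)$ suggested in the comments, so for the same $r$ the output agrees with the basic rule $r \le \sigma(a)$ except for floating-point ties right at the boundary.

```python
import math
import random

def sigmoid(a: float) -> float:
    """Overflow-safe logistic function; this is the call we want to avoid."""
    if a >= 0:
        return 1.0 / (1.0 + math.exp(-a))
    e = math.exp(a)
    return e / (1.0 + e)

SIGMOID_2 = sigmoid(2.0)     # precomputed once, about 0.880797077977882
C = 2.0 / (SIGMOID_2 - 0.5)  # chord constant, about 5.25214114 (computed, not hard-coded)

class FastSampler:
    """Returns +1 iff r <= sigmoid(a), calling sigmoid only when the cheap bounds cannot decide."""

    def __init__(self) -> None:
        self.sigmoid_calls = 0  # how many times the expensive sigmoid was evaluated

    def _sigmoid(self, a: float) -> float:
        self.sigmoid_calls += 1
        return sigmoid(a)

    def sample(self, a: float, r: float) -> int:
        if a < 0:
            # Reflection: r <= sigmoid(a) iff 1 - r >= sigmoid(-a) (up to ties),
            # so run the a >= 0 logic on (-a, 1 - r) and flip the sign.
            return -self.sample(-a, 1.0 - r)
        if a >= 2.0:
            # sigmoid(a) >= sigmoid(2), so r <= sigmoid(2) already means +1.
            # Python's `or` short-circuits: _sigmoid runs only if the first test fails.
            return 1 if (r <= SIGMOID_2 or r <= self._sigmoid(a)) else -1
        # 0 <= a < 2: tangent 0.5 + a/4 lies above sigmoid, chord 0.5 + a/C lies below it.
        if r > 0.5 + a / 4.0:
            return -1
        return 1 if (r <= 0.5 + a / C or r <= self._sigmoid(a)) else -1

if __name__ == "__main__":
    rng = random.Random(0)
    fs = FastSampler()
    n = 100_000
    for _ in range(n):
        fs.sample(rng.uniform(-6.0, 6.0), rng.random())
    print(fs.sigmoid_calls / n)  # expected to stay below about 0.119
```

Because `r` is an explicit argument, the same `(a, r)` pairs can be fed to both this sampler and the basic rule to confirm that the decisions coincide while most `exp` calls are skipped.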

The idea is as follows:

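- If $a \ge 2$, then we can lower-bound the probability of outputting $+1$ by the pre-computed value $\sigma(2) \approx 0.8808$, and short-circuit the computation in many cases.
- If $a < 2$, then we can upper-bound the sigmoid function by $0.5 + a/4$, its tangent line at $a = 0$ (since $\sigma'(0) = 1/4$).
- If $a < 2$, then we can also lower-bound it by $0.5 + a/5.252141128658$. (This constant was found numerically. It agrees to seven decimal places with $2/(\sigma(2) - 1/2) = 4/\tanh(1) \approx 5.2521411420$, the value that makes $0.5 + a/c$ the chord of $\sigma$ between $a = 0$ and $a = 2$.) Both bounds hold on $[0, 2]$ because $\sigma$ is concave for $a \ge 0$: its tangent lies above it and its chords lie below it.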
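The two bounds and the constant can be checked numerically with a short Python sketch. `C_CHORD` and `bounds_hold` are hypothetical names, and the grid check is only a sanity test, not a proof. By this check, the exact chord value passes, while the printed constant 5.252141128658, being about $1.3 \times 10^{-8}$ smaller, makes the lower bound overshoot $\sigma(2)$ by roughly $10^{-9}$ at $a = 2$. That is harmless in practice, but it is one reason to compute the constant rather than hard-code it.

```python
import math

def sigmoid(a: float) -> float:
    """Logistic function 1 / (1 + exp(-a)); only called with a >= 0 here."""
    return 1.0 / (1.0 + math.exp(-a))

# Reciprocal slope of the chord from (0, 1/2) to (2, sigmoid(2)); equals 4 / tanh(1).
C_CHORD = 2.0 / (sigmoid(2.0) - 0.5)

def bounds_hold(c: float, n: int = 100_000, tol: float = 1e-12) -> bool:
    """Check 0.5 + a/c <= sigmoid(a) <= 0.5 + a/4 on a uniform grid over [0, 2]."""
    for i in range(n + 1):
        a = 2.0 * i / n
        s = sigmoid(a)
        if not (0.5 + a / c <= s + tol and s <= 0.5 + a / 4.0 + tol):
            return False
    return True

print(C_CHORD)               # approximately 5.25214114
print(bounds_hold(C_CHORD))  # True
```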

The three cases are illustrated in the following figure, where the input $a$ is on the x-axis and the random number $r$ is on the y-axis.

Since, for all $a$, at least a fraction $\sigma(2) \approx 0.8808$ of the random numbers will be short-circuited, sigmoid calls will be reduced by at least that fraction. If $a$ is often near zero, you will do even better.
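The claim can be checked with a short calculation for $a \ge 0$, assuming (for this derivation) that the lower-bound constant is exactly the chord value $c = 2/(\sigma(2) - 1/2)$, so that $1/c = (\sigma(2) - 1/2)/2$. The sigmoid is evaluated only when neither cheap test decides:

$$
\Pr[\sigma(a)\text{ evaluated}] =
\begin{cases}
\Pr\!\left[\tfrac12 + \tfrac{a}{c} < r \le \tfrac12 + \tfrac{a}{4}\right]
  = a\left(\tfrac14 - \tfrac1c\right) = \tfrac{a}{2}\bigl(1 - \sigma(2)\bigr), & 0 \le a < 2,\\[4pt]
\Pr\!\left[r > \sigma(2)\right] = 1 - \sigma(2) \approx 0.1192, & a \ge 2.
\end{cases}
$$

Both cases are at most $1 - \sigma(2)$, and the first shrinks linearly as $a \to 0$, which is why inputs near zero do even better.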

Obviously, you can take this further by adding more cases, which may or may not be helpful depending on the architecture and on the cost of branching vs. the cost of computing an $\exp$.


Neat. This technique is known as rejection sampling (a famous example is the Ziggurat algorithm for the Normal distribution).

Wait, actually this isn’t the same thing as rejection sampling, but it’s basically the same idea.

Alternatively, you could use the “Gumbel max trick” (https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/), which avoids the exp because you can stay in the natural parameterization. However, the Gumbel max trick requires twice as many random numbers, and each Gumbel variate is typically generated by applying two logs to a uniform variate, $-\log(-\log(u))$ with $u \sim \mathrm{uniform}(0,1)$. So why consider this at all? Two logs are about as expensive as two exps, and we only needed one exp! The saving grace is that the Gumbel variates can be precomputed because they don’t depend on the data. If you amortize the cost of precomputing the Gumbel variates, the Gumbel max trick is wicked fast.
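A sketch of the Gumbel max variant described in this comment, specialized to two outcomes with logits $a$ (for $+1$) and $0$ (for $-1$). The names are hypothetical, and refilling a pool in fixed-size, data-independent batches is only one assumed way to "precompute" the variates; each variate is still used exactly once, so samples stay independent.

```python
import math
import random

def make_gumbel_pool(n: int, rng: random.Random) -> list:
    """Precompute n standard Gumbel variates g = -log(-log(u)) with u ~ Uniform(0, 1)."""
    pool = []
    while len(pool) < n:
        u = rng.random()  # in [0, 1); reject u == 0, where -log(-log(u)) is undefined
        if u > 0.0:
            pool.append(-math.log(-math.log(u)))
    return pool

class GumbelBinarySampler:
    """Outputs +1 w.p. sigmoid(a) via Gumbel-max over the logits (a, 0); no exp per sample."""

    def __init__(self, batch_size: int = 4096, seed: int = 0) -> None:
        self.rng = random.Random(seed)
        self.batch_size = batch_size
        self.pool: list = []

    def _gumbel(self) -> float:
        # Refill in data-independent batches; pop() consumes each variate once.
        if not self.pool:
            self.pool = make_gumbel_pool(self.batch_size, self.rng)
        return self.pool.pop()

    def sample(self, a: float) -> int:
        g1, g2 = self._gumbel(), self._gumbel()
        # argmax of (a + g1, 0 + g2) picks +1 with probability e^a / (e^a + 1) = sigmoid(a)
        return 1 if a + g1 > g2 else -1
```

The per-sample work in `sample` is one addition and one comparison; all logs happen inside `make_gumbel_pool`, which could equally be run ahead of time.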

Also, in the case of negative $a$, I believe you need to negate $r$, i.e. $-\mathrm{sample}(-a, 1-r)$.

Hi Tim,

I love it! If I understand it correctly, you just do the following:

1. Input $a$
2. Sample $g_1 \sim \mathrm{Gumbel}(0, 1)$
3. Sample $g_2 \sim \mathrm{Gumbel}(0, 1)$
4. If $a + g_1 > g_2$, output $+1$
5. Else, output $-1$

Yup!

And, yes, if you want identical samples for the same $r$, you should negate it. (Though you still get samples from the same distribution if you don’t negate, so it probably isn’t worth the effort in most cases.)

Ah, right. Thanks.

Even simpler method, which is similar to the Gumbel trick:

Let u = uniform(0,1)

return +1 if a > logit(u) else -1

It’s super obvious in hindsight!

* logit is the inverse of sigmoid.

* logit is strictly monotonically increasing, so you can apply it to both sides of the greater-than:

sigmoid(a) > u ⟺ logit(sigmoid(a)) > logit(u) ⟺ a > logit(u).
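A sketch of this logit-based sampler in Python. `logit`, `sample_logit`, and `uniform_open` are hypothetical helper names; `uniform_open` rejects $u = 0$, which Python's `random()` can return and where logit is undefined. Since logit(u) does not depend on $a$, it could also be precomputed in the same spirit as the Gumbel variates.

```python
import math
import random

def logit(u: float) -> float:
    """Inverse of the sigmoid: log(u / (1 - u)), defined for 0 < u < 1."""
    return math.log(u / (1.0 - u))

def sample_logit(a: float, u: float) -> int:
    """+1 iff a > logit(u), equivalently sigmoid(a) > u; one log and no exp."""
    return 1 if a > logit(u) else -1

def uniform_open(rng: random.Random) -> float:
    """Draw u from (0, 1) by rejecting the endpoint 0.0 that random() can return."""
    u = rng.random()
    while u == 0.0:
        u = rng.random()
    return u
```

For the same $u$, `sample_logit(a, u)` makes the same decision as comparing $u$ against $\sigma(a)$ directly, apart from floating-point ties at the boundary.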

Bonus: It also turns out that the difference of two Gumbel RVs is logistic, which is unsurprising given my previous comment. So, generating two Gumbel RVs is unnecessary!
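The bonus claim has a short derivation. For independent standard Gumbel variates $g_1, g_2$ with CDF $F(t) = \exp(-e^{-t})$ and density $f(t) = e^{-t}\exp(-e^{-t})$, condition on $g_2 = t$ and substitute $s = e^{-t}$:

$$
\begin{aligned}
\Pr[g_1 - g_2 \le x] &= \int_{-\infty}^{\infty} F(t + x)\, f(t)\, dt
  = \int_{-\infty}^{\infty} \exp\!\left(-e^{-x} e^{-t}\right) e^{-t} \exp\!\left(-e^{-t}\right) dt \\
&= \int_{0}^{\infty} \exp\!\left(-s\,(1 + e^{-x})\right) ds
  = \frac{1}{1 + e^{-x}} = \sigma(x).
\end{aligned}
$$

So $g_1 - g_2$ is standard logistic, exactly like $\mathrm{logit}(u)$, since $\Pr[\mathrm{logit}(u) \le x] = \Pr[u \le \sigma(x)] = \sigma(x)$. The Gumbel test $a + g_1 > g_2$ is the same as $a > g_2 - g_1$, and $g_2 - g_1$ is also standard logistic by symmetry, so a single $\mathrm{logit}(u)$ can stand in for the pair of Gumbel variates.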