Memory-efficient numpy.random.choice()
NumPy's numpy.random.choice() method samples items at random from an array.
Unfortunately, sampling integers without replacement from $0$ to $N-1$ requires $O(N)$ memory. Thus, the following code will result in an OOM error (if it doesn't outright crash your machine):
```python
import numpy as np

N = 1_000_000_000_000
np.random.choice(N, size=2, replace=False)
```
This is because numpy.random.choice() creates (at least up to NumPy version 1.24) an array of size $N$ (in our example, ~8 TB!), as can be seen in the source code:
```python
def choice(self, a, size=None, replace=True, p=None):
    ...
    # pop_size == a when `a` is an `int`
    idx = self.permutation(pop_size)[:size]
    ...
    if a.ndim == 0:
        return idx

def permutation(self, object x):
    ...
    if isinstance(x, (int, np.integer)):
        # this creates an array of size `x`
        arr = np.arange(x)
        self.shuffle(arr)
        return arr
```
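To get a sense of the scale, here is my own back-of-the-envelope check (not part of NumPy): np.arange(N) materializes $N$ 64-bit integers, 8 bytes each.

```python
import numpy as np

N = 1_000_000_000_000  # 10^12, as in the example above
# Each element of np.arange(N) is a 64-bit (8-byte) integer.
bytes_needed = N * np.dtype(np.int64).itemsize
print(bytes_needed / 1e12)  # prints 8.0 (terabytes)
```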
However, intuitively, we shouldn't need to create an array of size $N$ if we are only sampling a small number of items $k$, right?
We could simply keep sampling with replacement (which doesn't require $O(N)$ memory) until we have $k$ distinct items.
This very basic and straightforward algorithm can be stated as follows:
Algorithm 1. Draw $k > 0$ items without replacement from a universe of size $N$, given a primitive $choice(N, k)$ that samples $k$ items with replacement. Space complexity: $O(k \log{N})$.

Require: $N \geq (2 + \sqrt{2}) \cdot k$
1: $X \leftarrow \{choice(N, 1)\}$
2: while $\vert X\vert \neq k$ do
3:   $s \leftarrow choice(N, 1)$
4:   if $s \notin X$ then
5:     $X \leftarrow X \cup \{ s \}$
6:   end if
7: end while
In code, this looks like:

```python
import numpy as np

def choice(N: int, k: int = 1, replace: bool = True, p=None):
    """Sample `k` elements from 0 to `N-1`, with or without replacement."""
    if not replace and N >= (2 + 2**0.5) * k:  # 2 + sqrt(2)
        # np.random.choice with no `size` returns a scalar; converted
        # to int it is hashable and can be stored in a set.
        X = {int(np.random.choice(N, p=p))}
        while len(X) != k:
            s = int(np.random.choice(N, p=p))
            X.add(s)  # average O(1) insert; duplicates are ignored
        return np.fromiter(X, dtype=np.int64, count=k)
    return np.random.choice(N, size=k, replace=replace, p=p)
```
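As a quick usage check (my own example; the choice function from above is re-declared so the snippet is self-contained): sampling a handful of distinct indices from an astronomically large range no longer allocates an $O(N)$ array.

```python
import numpy as np

def choice(N: int, k: int = 1, replace: bool = True, p=None):
    # Same rejection-sampling strategy as in the article's version.
    if not replace and N >= (2 + 2**0.5) * k:
        X = {int(np.random.choice(N, p=p))}
        while len(X) != k:
            X.add(int(np.random.choice(N, p=p)))
        return np.fromiter(X, dtype=np.int64, count=k)
    return np.random.choice(N, size=k, replace=replace, p=p)

# Five distinct indices from a universe of a trillion items,
# using only O(k log N) memory.
sample = choice(10**12, k=5, replace=False)
print(sorted(sample.tolist()))
```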
But why require that $N \geq (2 + \sqrt{2})\cdot k$? We will figure it out by proving a few theorems.
Theorem 1 (Space complexity). Algorithm 1 has space complexity $O(k \log{N})$.

Proof. Both $X$ and $choice(N, k)$ occupy $O(k \log{N})$ space, since $\vert X\vert$ is at most $k$ and our largest number, $N$, requires $O(\log{N})$ bits to be stored.
This one was very simple. But what about time complexity?

Theorem 2 (Expected time complexity). If $N \geq 2k$, Algorithm 1 draws $O(k)$ items in expectation.

Proof. In order to advance in the while loop, we need to add new elements to $X$. For a fixed $\vert X\vert$, the probability of sampling the first unique item on draw $i+1$ is: $$ \begin{align} P\left[\text{unique item on draw }i+1\right] &= \left(\frac{\vert X\vert}{N}\right)^i \cdot \left( 1 - \frac{\vert X\vert}{N} \right) \\ &\leq \left(\frac{k-1}{N}\right)^i \cdot 1 \end{align} $$ Using this probability bound, the expected number of items drawn is: $$ \begin{align} \mathbb{E}[\text{\# of items drawn}] \leq 1 &+ \mathbb{E}[\text{\# of draws to grow }\vert X\vert\text{ from }1\text{ to }2] \\ &+ \cdots \\ &+ \mathbb{E}[\text{\# of draws to grow }\vert X\vert\text{ from }k-1\text{ to }k] \\ \leq 1 &+ (k-1)\sum_{i=1}^\infty i\left(\frac{k-1}{N}\right)^{i} \end{align} $$ since we have to draw one item in line 1 of the algorithm and then $(k-1)$ new unique items in the while loop. Finally, $$ \begin{align} N \geq (2 + \sqrt{2})\cdot k \geq 2k \;&\Rightarrow\; \frac{k-1}{N} \leq 2^{-1} \\ \mathbb{E}[\text{\# of items drawn}] &\leq 1 + (k-1) \sum_{i=1}^\infty i\cdot 2^{-i} \\ &= 1 + (k-1)\cdot 2 \\ &= O(k) \end{align} $$
Note that sets in Python have average $O(1)$ lookup and insert times, which is required for lines 4 and 5 of the algorithm.
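As an informal sanity check of Theorem 2 (my own experiment, not part of the proof), we can count how many with-replacement draws the rejection loop actually makes and compare against the $1 + 2(k-1)$ bound:

```python
import random

def draws_until_k_unique(N, k, rng):
    """Count total with-replacement draws until k distinct items are seen."""
    X, draws = set(), 0
    while len(X) != k:
        X.add(rng.randrange(N))  # one with-replacement draw
        draws += 1
    return draws

rng = random.Random(0)
N, k = 1000, 100  # N = 10k, comfortably above (2 + sqrt(2)) * k
trials = 2000
avg = sum(draws_until_k_unique(N, k, rng) for _ in range(trials)) / trials
print(avg)                      # empirical mean number of draws
print(avg <= 1 + 2 * (k - 1))   # prints True: well within the O(k) bound
```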
However, this is not all. We used the fact that $N \geq 2k$ in our proof, but let's now generalize to $N \geq ak$ with $a > 1$. Note that $a = 1$ would include the possibility that $N = k$, in which case we just sample every item in the universe of size $N$, which is not interesting for our purposes. Also, $a < 1$ includes cases where $N < k$, for which sampling without replacement is impossible.

Lemma 1. If $N \geq a\cdot k$ with $a > 1$, then $\mathbb{E}[\text{\# of items drawn}] \leq 1 + (k-1)\frac{a}{(a-1)^2}$.

Proof. $N \geq ak$ implies that
$$
\begin{align}
\frac{k-1}{N} &\leq a^{-1} \\
&\downarrow \\
\mathbb{E}[\text{\# of items drawn}] &\leq 1 + (k-1) \sum_{i=1}^\infty i\cdot a^{-i} \\
&= 1 + (k-1)\frac{a}{(a-1)^2}
\end{align}
$$
since
$$
\sum_{i=1}^\infty i\cdot a^{-i} = \frac{a}{(a-1)^2} \qquad \forall a,\ \vert a\vert > 1
$$
With this lemma we can prove the following theorem:

Theorem 3. If $N \geq (2 + \sqrt{2})\cdot k$, then Algorithm 1 draws more than $1 + 2k$ items with probability less than $(k-1)/N$.

Proof. We define $Y := (\text{\# of items drawn} - 1)$, which is the number of items drawn in the while loop. We recall Markov's inequality:
$$
P\left[ Y \geq \delta\cdot \mathbb{E}[Y] \right] \leq \frac{1}{\delta} \qquad \delta > 0
$$
So that, using Lemma 1,
$$
P\left[ Y \geq \delta\cdot (k-1) \cdot \frac{a}{(a-1)^2} \right] \leq P\left[ Y \geq \delta\cdot \mathbb{E}[Y] \right] \leq \frac{1}{\delta}
$$
Now we set $\delta = N/(k-1)$, and consider the maximum value of $a$, $a^* = N/k$:
$$
P\left[ Y \geq N \cdot \frac{a^*}{(a^*-1)^2} \right] \leq \frac{k-1}{N}
$$
Note that this is true since Lemma 1 still applies when $a = a^*$. We now observe that
$$
N \frac{a^*}{(a^*-1)^2} = k\frac{N^2}{(N-k)^2} = k\left(1 - \frac{k}{N}\right)^{-2}
$$
so that
$$
P\left[ Y \geq k\cdot\left(1-\frac{k}{N}\right)^{-2} \right] \leq \frac{k-1}{N}
$$
Finally, the requirement on $N$ bounds the threshold by $2k$:
$$
N \geq \left(2 + \sqrt{2}\right)\cdot k \quad \rightarrow \quad \left(1-\frac{k}{N}\right)^{-2} \leq \left(1-\frac{1}{2 + \sqrt{2}}\right)^{-2} = 2
$$
Thus, the algorithm draws more than $1 + 2k = O(k)$ items only with probability less than $(k-1)/N$. This is very useful: it means that if $N$ is massive and $k$ is minuscule in comparison, we won't draw more than $1 + 2k$ items, with high probability. In fact, if $N$ is 100 times $k$, then we will draw more than $1 + 1.021\cdot k$ items with probability less than ~1%. This is definitely better than the original implementation, which had to spend $O(N)$ time creating the array of length $N$.

To end this little exploration, note that, in general, if we had chosen $\delta = Nk^\gamma/(k-1)$ for some $\gamma > 0$, then the algorithm would draw more than $1 + k^{1+\gamma}\left(1-\frac{k}{N}\right)^{-2}$ items with probability less than $k^{1-\gamma}/N$. For example, with $\gamma = 1$ the number of draws exceeds $1 + k^2\left(1-\frac{k}{N}\right)^{-2}$ with probability less than $1/N$: drawing 10 items without replacement from a universe of 1,000 items will require more than ~103 draws with probability less than 0.1%.

But when do we have a massive $N$ in practice? Let me present the original problem I was trying to solve, which required me to write my own choice function: a memory-efficient way to sample without replacement from itertools.product().

Example use-case

Let's say that you have multiple iterables that can be indexed in Python, and you want to sample $k$ random tuples from their Cartesian product. Since the space of all such tuples is prohibitively large, we cannot simply generate all of them first and then sample $k$ of them. The solution is to use our $choice(N, k)$ function:
```python
def random_product(*iterables, size=None):
    """Memory-efficient way to sample ``itertools.product`` at random
    without replacement."""
    lens = [len(it) for it in iterables]
    # N is usually extremely large!
    N = np.prod(lens)
    if size is None:
        size = N
    # Convert flat indices in [0, N) into one index per iterable.
    idx = np.unravel_index(choice(N, k=size, replace=False), shape=lens)
    idx = np.array(idx).T
    for i_ in idx:
        yield tuple(iterables[i][i_[i]] for i in range(len(iterables)))
```
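Putting it together, here is a standalone usage sketch (my own example; compact versions of the choice and random_product functions from above are re-declared so it runs on its own):

```python
import numpy as np

def choice(N, k=1, replace=True, p=None):
    # Rejection sampling when the universe is large enough.
    if not replace and N >= (2 + 2**0.5) * k:
        X = {int(np.random.choice(N, p=p))}
        while len(X) != k:
            X.add(int(np.random.choice(N, p=p)))
        return np.fromiter(X, dtype=np.int64, count=k)
    return np.random.choice(N, size=k, replace=replace, p=p)

def random_product(*iterables, size=None):
    lens = [len(it) for it in iterables]
    N = np.prod(lens)
    size = N if size is None else size
    # Flat indices in [0, N) map one-to-one onto product tuples.
    idx = np.unravel_index(choice(N, k=size, replace=False), shape=lens)
    for i_ in np.array(idx).T:
        yield tuple(iterables[i][i_[i]] for i in range(len(iterables)))

# Sample 3 distinct (letter, digit, flag) tuples out of 26 * 10 * 2 = 520.
letters = "abcdefghijklmnopqrstuvwxyz"
digits = range(10)
flags = [True, False]
samples = list(random_product(letters, digits, flags, size=3))
print(samples)
print(len(set(samples)))  # prints 3: no repeated tuples
```

Because the flat indices returned by choice are distinct, the yielded tuples are guaranteed to be distinct as well.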