Can you solve this in O(1) time?



When I started learning programming, I came across this problem on Facebook; I no longer remember where the post was. The problem goes like this:

Write a function that takes integers k and s and counts the integer triplets (a,b,c) such that:
    0 \leq a,b,c \leq k \; \text{and} \; a+b+c=s

Note that the order within a triplet matters. For example, when k=s=2, the function should return 6, since there are 6 triplets satisfying the condition: (2,0,0), (0,2,0), (0,0,2), (1,1,0), (1,0,1) and (0,1,1).

The first, naive solution loops through all triplets (a,b,c). Here is the Python code:

def f1(k, s):
    m = max(0, s - 2*k)   # lower bound for a, b, c
    M = min(s, k)         # upper bound for a, b, c
    result = 0
    for a in range(m, M + 1):
        for b in range(m, M + 1):
            for c in range(m, M + 1):
                if a + b + c == s:
                    result += 1
    return result

It’s easy to see that m \leq a,b,c \leq M: each variable is at most \min(s,k), and since the other two sum to at most 2k, each is at least s-2k. These bounds may improve the performance a bit; otherwise one can simply loop through a,b,c = 0,...,k.
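To convince ourselves that restricting the loops to [m, M] drops no solutions, here is a small sanity check. It is only a sketch: the helper names `count_full` and `count_bounded` are our own, not from the original post. It compares a loop over the full range 0..k with a loop restricted to [m, M] for all small k and s, including sums s > 3k that have no solution.

```python
from itertools import product


def count_full(k, s):
    """Brute force: every variable ranges over the full interval 0..k."""
    return sum(1 for a, b, c in product(range(k + 1), repeat=3) if a + b + c == s)


def count_bounded(k, s):
    """Same count, but every variable only ranges over [m, M]."""
    m = max(0, s - 2 * k)  # lower bound for a, b, c
    M = min(s, k)          # upper bound for a, b, c
    r = range(m, M + 1)    # empty when m > M, i.e. when s > 3k
    return sum(1 for a, b, c in product(r, repeat=3) if a + b + c == s)


# Compare both counts for every small (k, s); s goes up to 3k + 2 so that
# the "no solution" region s > 3k is exercised too.
for k in range(6):
    for s in range(3 * k + 3):
        assert count_full(k, s) == count_bounded(k, s), (k, s)
```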

This solution is of course correct, but very slow: it takes O(n^3) time, where n = M-m+1 \leq k+1 is the number of candidate values for each variable. We can avoid the third loop over c with a simple observation: for each pair a,b, the only candidate is c = s-a-b, and it is valid exactly when m \leq s-a-b \leq M. Thus we have an O(n^2) solution:

def f2(k, s):
    m = max(0, s - 2*k)   # lower bound for a, b, c
    M = min(s, k)         # upper bound for a, b, c
    result = 0
    for a in range(m, M + 1):
        for b in range(m, M + 1):
            # c = s - a - b is determined; count it if it is in range
            if m <= s - a - b <= M:
                result += 1
    return result

I then discussed this problem with my friend, and after several days he gave me a brilliant O(n) solution:

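# O(n) solution
def f3(k, s):
    M = min(s, k)
    res = 0
    for a in range(k + 1):
        p = s - a                # b + c must equal p
        if p > 2*M or p < 0:
            continue             # no pair (b, c) in range
        if p >= k:
            res += 2*M - p + 1   # decreasing part (here M = k)
        else:
            res += p + 1         # increasing part
    return res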

Let me explain his idea: for each a, we count the pairs (b,c) such that 0 \leq b,c \leq k \; \text{and} \; b+c=s-a.

  • Two border cases: when s-a=2M or s-a=0, there is only 1 pair, (b,c)=(M,M) or (b,c)=(0,0) respectively.
  • As s-a grows from 0 to k, the number of pairs increases from 1 to k+1. After that (i.e. from s-a = k+1 on), it decreases by one at each step, down to 1.
s-a            | 0 | 1 | … | k   | k+1 | … | 2M
# pairs (b,c)  | 1 | 2 | … | k+1 | k   | … | 1
Number of pairs (b,c) such that b+c=s-a, versus the value of s-a
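The table can be summarized by a single expression: for 0 \leq p \leq 2k, the number of pairs is \min(p, 2k-p)+1, and it is 0 outside that range. Below is a minimal sketch (the names `pairs_count` and `count_by_pairs` are our own, not from the original post) that checks this expression against direct enumeration and then sums it over a, which is exactly the idea behind the O(n) solution.

```python
def pairs_count(p, k):
    """Number of integer pairs (b, c) with 0 <= b, c <= k and b + c == p."""
    if p < 0 or p > 2 * k:
        return 0
    # Rises 1, 2, ..., k+1 for p = 0..k, then falls back to 1 at p = 2k.
    return min(p, 2 * k - p) + 1


def count_by_pairs(k, s):
    """O(k) count: for each a, add the number of pairs (b, c) with b + c = s - a."""
    return sum(pairs_count(s - a, k) for a in range(k + 1))


# Check the closed form for pairs against direct enumeration.
for k in range(6):
    for p in range(-1, 2 * k + 2):
        direct = sum(1 for b in range(k + 1) for c in range(k + 1) if b + c == p)
        assert pairs_count(p, k) == direct, (k, p)
```

For example, `count_by_pairs(2, 2)` adds 3 + 2 + 1 for a = 0, 1, 2 and returns 6, matching the example with k = s = 2.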

This solution is beautiful, isn’t it? At first I believed it was optimal. Then, however, I realized that we can count the triplets exactly without any loop, although this requires a more careful mathematical analysis of the problem.

O(1) solution

It’s easy to see that:
m := \max(0, s-2k) \leq a,b,c \leq M := \min(s,k)

Let’s change the point of view: for each fixed a \in [m,M], p := b+c = s-a is also fixed. Thinking of b+c=p as a line in the coordinate plane, we want to count the integer points (b,c) on this line with 0 \leq b,c \leq k (each such point is a solution of the equation b+c=p), and then sum these counts over all p. Doing this with one loop over a is exactly the O(n) solution above.

However, if we know the lower and upper bounds for p, we can compute this sum in closed form, without a loop. To do this, we need to study the range of p := b+c.
Since b+c = s-a and m \leq a,b,c \leq M, we always have the following “global” bounds for p = b+c (the lower bound comes from b,c \geq m and a \leq M, the upper bound from b,c \leq M and a \geq m):
u := \max(2m, s-M) \leq b+c \leq v := \min(2M, s-m)
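A quick empirical check of these bounds, as a sketch (`bounds_uv` and `realized_range` are hypothetical helper names introduced here): whenever at least one triplet exists, i.e. s \leq 3k, the smallest and largest values of b+c over all valid triplets coincide with u and v, so for small inputs the bounds turn out to be tight as well as valid.

```python
from itertools import product


def bounds_uv(k, s):
    """The 'global' bounds u <= b + c <= v from the text."""
    m = max(0, s - 2 * k)  # lower bound for a, b, c
    M = min(s, k)          # upper bound for a, b, c
    return max(2 * m, s - M), min(2 * M, s - m)


def realized_range(k, s):
    """Smallest and largest b + c over all valid triplets, or None if there are none."""
    sums = [b + c for a, b, c in product(range(k + 1), repeat=3) if a + b + c == s]
    return (min(sums), max(sums)) if sums else None


for k in range(6):
    for s in range(3 * k + 1):  # s <= 3k, so at least one triplet exists
        assert realized_range(k, s) == bounds_uv(k, s), (k, s)
```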

Case 1: s-2k \geq 0. In this case we have:

  • m = s-2k \geq 0 and M = k (since s \geq 2k \geq k)
  • v = 2M. Why?
    Proof. s-m = 2k, so v = \min(2M, s-m) = \min(2M, 2k) = 2M since M = k.
  • u = \max(2m, s-M) \geq s-M \geq M, because s \geq 2k = 2M.

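Why do we need this information? Starting from the line b+c=2M, there is only 1 option for (b,c), namely b=c=M (this line does occur, since we proved that v=2M). On the line b+c=2M-1 there are 2 solutions, and so on. We stop at the line b+c=u, which contains v-u+1 integer pairs (solutions). Because u \geq M, every line in between lies on the decreasing side of the table above, so each step down adds exactly one point, and the expected result (= total number of points) is 1+2+…+(v-u+1) = (v-u+1)(v-u+2)/2. This requires u \leq v, i.e. s \leq 3k; if s > 3k there are no triplets at all (since a+b+c \leq 3k), so that case has to be handled separately.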
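The Case 1 count can be checked numerically. This is a minimal sketch, assuming 2k \leq s \leq 3k (the range where Case 1 has solutions); `case1_closed_form` and `brute_force` are names introduced here for illustration. It sums the line counts 1, 2, ..., v-u+1 explicitly, compares the result with (v-u+1)(v-u+2)/2, and compares both with brute force. For example, k=2 and s=5 give u=3 and v=4: the line b+c=4 contains (2,2) and the line b+c=3 contains (1,2) and (2,1), so 1+2 = 3 triplets.

```python
from itertools import product


def brute_force(k, s):
    return sum(1 for t in product(range(k + 1), repeat=3) if sum(t) == s)


def case1_closed_form(k, s):
    """Case 1 (2k <= s <= 3k): 1 + 2 + ... + (v - u + 1) as a closed form."""
    m = s - 2 * k   # m = max(0, s - 2k) = s - 2k here
    M = k           # M = min(s, k) = k here
    u = max(2 * m, s - M)
    v = min(2 * M, s - m)
    n = v - u + 1                    # number of points on the lowest line b + c = u
    explicit = sum(range(1, n + 1))  # 1 + 2 + ... + n, line by line
    closed = n * (n + 1) // 2        # same sum; n(n+1) is always even
    assert explicit == closed
    return closed


for k in range(6):
    for s in range(2 * k, 3 * k + 1):
        assert case1_closed_form(k, s) == brute_force(k, s), (k, s)
```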

Note that the lower bound u also guarantees that we only count points in the admissible region: on a line b+c=p with p \geq u, every point (b,c) of the square [0,M] \times [0,M] satisfies b = p-c \geq u-M \geq s-2M = m, and likewise c \geq m, so it automatically lies in the admissible square [m,M] \times [m,M].

Case 2: s-2k < 0. In this case we have:

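  • m = 0
  • s < 2M whenever s > 0. Why?
    Proof. If k \leq s, then M = \min(s,k) = k, and hence 2M = 2k > s. On the other hand, if k > s, then M = \min(s,k) = s, so 2M = 2s > s (as s > 0). In both cases, s < 2M. (If s = 0, then M = 0 and the only triplet is (0,0,0); the formula below still gives 1.)
  • u = \max(2m, s-M) = s-M < M, because s < 2M (note that m = 0)
  • v = \min(2M, s-m) = \min(2M, s) = s < 2M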

Since m = 0, in this case the admissible region is the whole square [0,M] \times [0,M], and we count its points lying between the lines b+c=u and b+c=v. The condition v < 2M means that, unlike in the first case, the sum does not start from 1: we start at the line b+c=v, which passes below the corner (M,M) of the square. The condition u < M means that we stop at the line b+c=u, which lies below the diagonal b+c=M of the square. We can count the points as follows:

result = (# points in the square) - (# points below the line b+c=u) - (# points above the line b+c=v)


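In formulas: the points below the line b+c=u form a triangle with 1+2+…+u = u(u+1)/2 points, and by the symmetry (b,c) \mapsto (M-b, M-c), the points above the line b+c=v form a triangle with (2M-v)(2M-v+1)/2 points. Hence the result is (M+1)^2 - u(u+1)/2 - (2M-v)(2M-v+1)/2.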
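The complement count can also be checked numerically. A sketch, assuming s < 2k (Case 2); `case2_parameters`, `count_between`, `case2_closed_form` and `brute_force` are illustrative names, not from the original post. `count_between` counts the points of the square [0,M] \times [0,M] with u \leq b+c \leq v one by one, while the closed form subtracts the two corner triangles from (M+1)^2. For instance, k=3 and s=4 give M=3, u=1, v=4, and 16 - 1 - 3 = 12.

```python
from itertools import product


def brute_force(k, s):
    return sum(1 for t in product(range(k + 1), repeat=3) if sum(t) == s)


def case2_parameters(k, s):
    m = 0                  # since s < 2k
    M = min(s, k)
    u = max(2 * m, s - M)  # equals s - M here
    v = min(2 * M, s - m)  # equals s here
    return M, u, v


def count_between(k, s):
    """Points (b, c) of [0, M] x [0, M] with u <= b + c <= v, counted one by one."""
    M, u, v = case2_parameters(k, s)
    return sum(1 for b, c in product(range(M + 1), repeat=2) if u <= b + c <= v)


def case2_closed_form(k, s):
    """(M+1)^2 minus the points below b + c = u and above b + c = v."""
    M, u, v = case2_parameters(k, s)
    below = u * (u + 1) // 2                    # points with b + c <= u - 1
    above = (2 * M - v) * (2 * M - v + 1) // 2  # points with b + c >= v + 1
    return (M + 1) ** 2 - below - above


for k in range(1, 7):
    for s in range(2 * k):  # Case 2: s < 2k
        assert case2_closed_form(k, s) == count_between(k, s) == brute_force(k, s), (k, s)
```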
Thus, we have an O(1) solution:

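def f4(k, s):
    if s < 0 or s > 3*k:
        return 0         # a + b + c can only take values in [0, 3k]
    m = max(0, s - 2*k)  # lower bound for a, b, c
    M = min(s, k)        # upper bound for a, b, c
    u = max(2*m, s - M)  # lower bound for b + c
    v = min(2*M, s - m)  # upper bound for b + c

    # Integer division is exact: each product of consecutive integers is even.
    if s - 2*k >= 0:
        # Case 1: res = 1 + 2 + ... + (v-u+1)
        return (v - u + 2) * (v - u + 1) // 2
    else:
        # Case 2: whole square minus the two corner triangles
        return (M + 1)**2 - u*(u + 1)//2 - (2*M - v)*(2*M - v + 1)//2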
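As an independent cross-check, the same count also follows from a standard technique that the post does not use: stars and bars with inclusion-exclusion. Without the upper bound k, there are \binom{s+2}{2} ordered triplets of non-negative integers with sum s; correcting for the triplets in which some entries exceed k gives \sum_{j=0}^{3} (-1)^j \binom{3}{j} \binom{s-j(k+1)+2}{2}, where terms with s-j(k+1) < 0 are dropped. The sketch below (`count_inclusion_exclusion` is our own name) assumes Python 3.8+ for `math.comb`.

```python
from itertools import product
from math import comb


def count_inclusion_exclusion(k, s):
    """Ordered triplets 0 <= a, b, c <= k with a + b + c == s, via inclusion-exclusion.

    Term j counts the solutions in which j chosen entries are forced to exceed k
    (stars and bars after subtracting k + 1 from each of those entries).
    """
    total = 0
    for j in range(4):
        rest = s - j * (k + 1)
        if rest < 0:
            break  # all later terms are negative as well
        total += (-1) ** j * comb(3, j) * comb(rest + 2, 2)
    return total


# Compare with brute force on small inputs, including sums s > 3k.
for k in range(6):
    for s in range(3 * k + 4):
        brute = sum(1 for t in product(range(k + 1), repeat=3) if sum(t) == s)
        assert count_inclusion_exclusion(k, s) == brute, (k, s)
```

Since everything stays in integer arithmetic, this function can also be compared with f4 on large inputs such as k = 10**6.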


This was just a toy problem rather than anything serious, but solving it brought me a surprising amount of joy when I first began learning programming. It reminded me that if I stay curious enough and keep digging, there is almost always a better solution waiting to be discovered.