When I started learning programming, I came across this problem on Facebook, though I have since forgotten where the post is. The problem goes like this:
Write a function with integer input $k, s$
to count how many integer triplets $(a, b, c)$
such that:
$0 \le a, b, c \le k$ and $a + b + c = s$.

Note that the order within a triplet matters. For example, when $k = 2, s = 2$, the function should return 6 since there are 6 triplets satisfying the condition: (2,0,0), (0,2,0), (0,0,2), (1,1,0), (1,0,1) and (0,1,1).
The first naive solution is to loop through all triplets $(a, b, c)$. Here is the Python code:
def f1(k, s):
    m = max(0, s - 2 * k)  # lower bound for a, b, c
    M = min(s, k)          # upper bound for a, b, c
    result = 0
    for a in range(m, M + 1):
        for b in range(m, M + 1):
            for c in range(m, M + 1):
                if a + b + c == s:
                    result += 1
    return result
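It’s easy to see that $m \le a, b, c \le M$: each variable is at most $\min(s, k)$, and at least $s - 2k$ because the other two contribute at most $2k$. These tighter bounds may improve the performance a bit. Otherwise one can just loop through $0 \le a, b, c \le k$.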
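To convince myself of the bounds, here is a small sketch that enumerates every solution over the full range and checks that each coordinate lands in $[m, M]$. The helper names `solutions`, `bounds` and `bounds_hold` are made up for this illustration and are not part of the original code.

```python
from itertools import product

def solutions(k, s):
    """All triplets (a, b, c) with 0 <= a, b, c <= k and a + b + c == s."""
    return [t for t in product(range(k + 1), repeat=3) if sum(t) == s]

def bounds(k, s):
    """The bounds used in f1: m = max(0, s - 2k), M = min(s, k)."""
    return max(0, s - 2 * k), min(s, k)

def bounds_hold(k, s):
    """True if every coordinate of every solution lies in [m, M]."""
    m, M = bounds(k, s)
    return all(m <= x <= M for t in solutions(k, s) for x in t)

print(len(solutions(2, 2)))  # 6, the example from the problem statement
print(all(bounds_hold(k, s) for k in range(6) for s in range(3 * k + 1)))  # True
```

This only checks small inputs, so it is a sanity check rather than a proof.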
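This solution is of course correct, but very slow, since it has $O(n^3)$ time complexity (here $n$ is the number of candidate values for each variable). Actually, we can avoid the third loop over $c$ by a simple observation: for each pair $(a, b)$, if $m \le s - a - b \le M$, then automatically there is an integer $c$ such that $a + b + c = s$, namely $c = s - a - b$. Thus we have an $O(n^2)$ solution: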
def f2(k, s):
    m = max(0, s - 2 * k)  # lower bound for a, b, c
    M = min(s, k)          # upper bound for a, b, c
    result = 0
    for a in range(m, M + 1):
        for b in range(m, M + 1):
            if m <= s - a - b <= M:
                result += 1
    return result
I then discussed this problem with my friend, and after several days he gave me a brilliant O(n) solution:
# O(n) solution
def f3(k, s):
    M = min(s, k)
    res = 0
    for a in range(k + 1):
        if s - a > 2 * M or s - a < 0:
            continue
        else:
            if s - a >= k:
                res += 2 * M - (s - a) + 1
            else:
                res += s - a + 1
    return res
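Let me explain his idea: for each $a$, we count how many pairs $(b, c)$ there are such that $b + c = s - a$.
- Two border cases: when $s - a = 0$ or $s - a = 2M$, there is only 1 pair, $(0, 0)$ or $(M, M)$ respectively.
- As $s - a$ goes from 0 to $k$, the number of pairs increases from 1 to k+1. After that (i.e. when $s - a > k$) it starts decreasing to 1.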
| s-a | 0 | 1 | … | k | k+1 | … | 2M |
| # pairs (b,c) | 1 | 2 | … | k+1 | k | … | 1 |
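The table can be reproduced with a few lines of Python. This is only a sketch: `count_pairs`, `count_pairs_brute` and `count_triplets` are names I made up here, and the version below works directly with $k$ instead of $M$, so it is a variant of the idea rather than my friend's exact code.

```python
def count_pairs(t, k):
    """Number of pairs (b, c) with 0 <= b, c <= k and b + c == t."""
    if t < 0 or t > 2 * k:
        return 0
    # t + 1 pairs while t <= k, then 2k - t + 1 pairs once t passes k
    return min(t, 2 * k - t) + 1

def count_pairs_brute(t, k):
    """Same count by direct enumeration, for comparison."""
    return sum(1 for b in range(k + 1) for c in range(k + 1) if b + c == t)

def count_triplets(k, s):
    """Sum the pair counts over every choice of a."""
    return sum(count_pairs(s - a, k) for a in range(k + 1))

print([count_pairs(t, 3) for t in range(7)])  # [1, 2, 3, 4, 3, 2, 1]
print(count_triplets(2, 2))                   # 6
```

For $k = 3$ the counts rise from 1 to $k + 1 = 4$ and then fall back to 1, exactly the shape in the table.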
This solution is beautiful, isn’t it? At first I believed this was the optimal solution. Then, however, I realized we can count the triplets exactly without any loop. But this of course requires more work to analyze the problem mathematically.
O(1) solution
It’s easy to see that $m \le a, b, c \le M$, with $m = \max(0, s - 2k)$ and $M = \min(s, k)$ as before.
Let’s change the point of view: for each fixed $a$, $s - a$ is also a fixed number. For each $a$, thinking of $b + c = s - a$ as a line in the coordinate plane, we want to count how many integer points $(b, c)$ lie on the line (the coordinates of such a point are a solution of the equation $b + c = s - a$), then sum them up over all $a$. This can be done by using 1 loop through $a$.
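However, if we know a lower bound and an upper bound for $b + c$, we can calculate the sum without using a loop. To do this, we need to study the range of $b + c$.
Since $m \le b, c \le M$ and $b + c = s - a$ with $m \le a \le M$, we always have the “global” lower and upper bounds for $b + c$:
$u = \max(2m, s - M) \le b + c \le \min(2M, s - m) = v$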
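As a quick check on these bounds for $b + c$, the sketch below compares the smallest and largest value of $b + c$ actually seen among the solutions with $\max(2m, s - M)$ and $\min(2M, s - m)$. The function names are hypothetical, chosen just for this example.

```python
from itertools import product

def sum_range_observed(k, s):
    """Smallest and largest b + c over all solutions, or None if there are none."""
    sums = [b + c for a, b, c in product(range(k + 1), repeat=3) if a + b + c == s]
    return (min(sums), max(sums)) if sums else None

def sum_range_formula(k, s):
    """The claimed bounds (u, v) = (max(2m, s - M), min(2M, s - m))."""
    m, M = max(0, s - 2 * k), min(s, k)
    return max(2 * m, s - M), min(2 * M, s - m)

for k, s in [(2, 2), (2, 5), (3, 4)]:
    print((k, s), sum_range_observed(k, s), sum_range_formula(k, s))
# (2, 2) (0, 2) (0, 2)
# (2, 5) (3, 4) (3, 4)
# (3, 4) (1, 4) (1, 4)
```

On these small inputs the bounds are not just valid but attained.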
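Case 1: $s \ge 2k$. In this case we have: $m = s - 2k$ and $M = k$ (since $s \ge 2k \ge k$)
$u = s - M$ and $v = 2M$. Why?
Proof. $v = \min(2M, s - m) = 2M$ since $s - m = 2k = 2M$, and $u = \max(2m, s - M) = s - M$ because we know that $s \le 3k$, which gives $2m = 2s - 4k \le s - k$. (The bound $s \le 3k$ can be seen from $a + b + c = s$ with $a, b, c \le k$.)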
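Why do we need this information? Starting from the line $b + c = v = 2M$, we have only 1 option for $(b, c)$, that is $(M, M)$ (we know this line indeed occurs since $v = 2M$ in this case). If $b + c = 2M - 1$, we have 2 solutions, and so on… We will stop at the line $b + c = u$, where we have $v - u + 1$ integer pairs (or solutions) on this line, and since we know $u = s - k \ge k = M$, so that every line we visit sits on or above the diagonal $b + c = M$ and the count grows by exactly 1 per step, the expected result (= # total points) is $1 + 2 + \dots + (v - u + 1) = (v - u + 2)(v - u + 1)/2$.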
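Note that the condition $u \ge M$ also indicates that we are counting the points in the admissible region: in this case $u = m + M$, so every point $(b, c)$ of the square $[0, M]^2$ on a line $b + c = t$ with $t \ge u$ automatically satisfies $b, c \ge m$. (Here we want to count the points inside the square $[0, M]^2$ whose coordinates are both at least $m$.)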
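Here is a small sketch of the Case 1 counting, walking the lines from $b + c = v$ down to $b + c = u$ and counting lattice points on each. The function name `case1_line_counts` is my own, and the code assumes $2k \le s \le 3k$.

```python
def case1_line_counts(k, s):
    """Points of [0, M]^2 on each line b + c = t, for t = v, v-1, ..., u.

    Assumes 2k <= s <= 3k, so that M = k, u = s - M and v = 2M.
    """
    M = k
    u, v = s - M, 2 * M
    counts = []
    for t in range(v, u - 1, -1):
        counts.append(sum(1 for b in range(M + 1) if 0 <= t - b <= M))
    return counts

print(case1_line_counts(2, 5))  # [1, 2]     -> total 3
print(case1_line_counts(3, 7))  # [1, 2, 3]  -> total 6
```

For $k = 3, s = 7$ we get $u = 4$, $v = 6$, and the total $1 + 2 + 3 = 6$ agrees with $(v - u + 2)(v - u + 1)/2 = 4 \cdot 3 / 2$.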
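Case 2: $s < 2k$. In this case $m = 0$, so $u = s - M$ and $v = \min(2M, s)$, and we have:
$u \le M$. Why?
Proof. If $s \le k$, then $M = s$, and hence $u = 0$. On the other hand, if $k < s < 2k$, we have $M = k$, then $u = s - k < k$. In both cases, we have $u \le M$.
$M \le v \le 2M$ because $v = \min(2M, s)$ (note that $s \ge M$)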
Since $m = 0$, we count all the points inside the square $[0, M]^2$ in this case. The last condition $v \le 2M$ means that we don’t necessarily start the sum from 1 like in the first case (we start at the line $b + c = v$, which lies on or below the vertex $(M, M)$ of the square). The condition $u \le M$ means that we will stop at the line $b + c = u$, which lies on or below the diagonal $b + c = M$ of the square. We can count the points as follows:
result = (# all points in the square) - (# points under the line b+c=u) - (# points above the line b+c=v)
As a formula, it is equal to (M+1)^2 - u(u+1)/2 - (2M-v)(2M-v+1)/2.
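To see the three pieces of this subtraction separately, here is a brute-force sketch. The name `case2_pieces` is made up for illustration, and the code assumes $0 \le s < 2k$.

```python
def case2_pieces(k, s):
    """Return (# points in the square, # below b+c=u, # above b+c=v).

    Assumes 0 <= s < 2k, so that m = 0, u = s - M and v = min(2M, s).
    """
    M = min(s, k)
    u, v = s - M, min(2 * M, s)
    square = [(b, c) for b in range(M + 1) for c in range(M + 1)]
    below = sum(1 for b, c in square if b + c < u)
    above = sum(1 for b, c in square if b + c > v)
    return len(square), below, above

print(case2_pieces(2, 3))  # (9, 1, 1), so the count is 9 - 1 - 1 = 7
```

For $k = 2, s = 3$ we have $M = 2$, $u = 1$, $v = 3$: the only point under the line $b + c = 1$ is $(0, 0)$ and the only point above $b + c = 3$ is $(2, 2)$, matching $u(u+1)/2 = 1$ and $(2M - v)(2M - v + 1)/2 = 1$.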
Thus, we have an O(1) solution:
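def f4(k, s):
    if s < 0 or s > 3 * k:
        return 0  # no triplet exists; the formulas below assume 0 <= s <= 3k
    m = max(0, s - 2 * k)   # lower bound for a, b, c
    M = min(s, k)           # upper bound for a, b, c
    u = max(2 * m, s - M)   # lower bound for b + c
    v = min(2 * M, s - m)   # upper bound for b + c
    if s - 2 * k >= 0:
        # res = 1 + 2 + ... + (v - u + 1)
        return (v - u + 2) * (v - u + 1) // 2
    else:
        # whole square minus the two corner triangles; // is exact here
        return (M + 1) ** 2 - u * (u + 1) // 2 - (2 * M - v) * (2 * M - v + 1) // 2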
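A cheap way to gain confidence in a closed form like this is to compare it with brute force on a small grid. The sketch below restates the two-case formula as `closed_form` (a hypothetical name), using integer division and returning 0 when $s$ falls outside $0 \le s \le 3k$. That last part is my assumption about how out-of-range inputs should behave.

```python
from itertools import product

def brute_force(k, s):
    """Count triplets by checking every (a, b, c) in {0..k}^3."""
    return sum(1 for t in product(range(k + 1), repeat=3) if sum(t) == s)

def closed_form(k, s):
    """Loop-free count; a restatement of the two-case formula."""
    if s < 0 or s > 3 * k:       # assumption: out-of-range sums count as 0
        return 0
    m, M = max(0, s - 2 * k), min(s, k)
    u, v = max(2 * m, s - M), min(2 * M, s - m)
    if s >= 2 * k:               # Case 1: a triangular number
        n = v - u + 1
        return n * (n + 1) // 2
    # Case 2: the square minus two corner triangles
    return (M + 1) ** 2 - u * (u + 1) // 2 - (2 * M - v) * (2 * M - v + 1) // 2

mismatches = [(k, s) for k in range(8) for s in range(-2, 3 * k + 3)
              if brute_force(k, s) != closed_form(k, s)]
print(mismatches)  # []
```

The grid deliberately includes a few values of $s$ below 0 and above $3k$, since those are the inputs where the case formulas on their own stop being meaningful.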
This was just a toy problem rather than anything serious, but solving it brought me a surprising amount of joy when I first began learning programming. It reminded me that if I stay curious enough and keep digging, there is almost always a better solution waiting to be discovered.
