How can I optimize my C++ code for a sum problem on CodeChef?

  • Context: C/C++ 
  • Thread starter: Saitama
  • Tags: C++, Sum

Discussion Overview

The discussion revolves around optimizing a C++ code implementation for calculating the sum of powers modulo a number, specifically the expression $$\sum_{i=1}^n i^i \pmod m$$. The context includes algorithmic efficiency and potential improvements to handle large values of n, which can be as high as $10^{18}$.

Discussion Character

  • Technical explanation
  • Debate/contested
  • Mathematical reasoning

Main Points Raised

  • One participant suggests using a table to store precomputed squares modulo m to speed up calculations, but another counters that this may not significantly improve performance due to cache misses and the necessity of iterating through a large range.
  • Another participant proposes factoring m into its prime factors and using the Chinese Remainder Theorem (CRT) to simplify the sum, expressing uncertainty about how to apply CRT in this context.
  • There is a discussion about the feasibility of performing $2^{60}$ iterations within a reasonable time frame, highlighting the need for a more efficient outer loop.
  • A later reply suggests that if n is large and m is small, caching might be effective, potentially allowing for a solution that meets time constraints.
  • One participant shares a Java implementation that can handle very large integers, discussing the limitations of C++'s long long type for large n and the advantages of using Java's BigInteger class.

Areas of Agreement / Disagreement

Participants express differing views on the effectiveness of precomputation and caching strategies, with no consensus on the best optimization approach. The discussion remains unresolved regarding the application of CRT and the overall strategy for handling large n.

Contextual Notes

Participants note that the outer loop must be optimized due to the impracticality of iterating through all values up to n when n is extremely large. There are also discussions about the limitations of different programming languages in handling large integers.

Saitama
I am trying this problem on CodeChef: Just a simple sum

My task is to evaluate:
$$\sum_{i=1}^n i^i \pmod m$$

Following is the code I have written:
Code:
#include <iostream>
using namespace std;

typedef long long ll;

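// Computes base^exponent mod M by binary exponentiation.
// Requires M * M to fit in a long long.
ll modularPower(ll base, ll exponent, ll M)
{
    ll res = 1 % M;
    base %= M;  // reduce first: for large i, base * base would otherwise overflow
    while (exponent)
    {
        if (exponent & 1)
            res = (res * base) % M;
        exponent >>= 1;
        base = (base * base) % M;
    }
    return res;
}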

int main() {
	ll T, N, M, i, res;
	cin>>T;
	while(T--) {
	    cin>>N>>M;
	    res=0;
	    for(i=1; i<=N; ++i)
	    {
	            res+=modularPower(i, i, M);
	            res%=M;
	    }
	    cout<<res<<endl;
	}
	return 0;
}

But this code exceeds the time limit.

How do I optimise it?

Thanks!
 
Hey Pranav! (Wave)

Standard speed-up of algorithms is by tracking intermediate results in a table.

First thing I can think of, is to calculate the square of each number up to $m$ and store it $\bmod m$.
So instead of calculating (base * base) % M, we can simply look it up. (Thinking)
 
I like Serena said:
Hey Pranav! (Wave)

Standard speed-up of algorithms is by tracking intermediate results in a table.

First thing I can think of, is to calculate the square of each number up to $m$ and store it $\bmod m$.
So instead of calculating (base * base) % M, we can simply look it up. (Thinking)

That won't really do much, a 64-bit modular multiplication is quite fast compared to a cache miss (because the precomputed table is going to be quite large and not accessed in a cache-friendly way). Besides, you still end up doing a loop from 1 to n, and since n can be up to $10^{18} \approx 2^{60}$ you're still going to be doing that many iterations, and $2^{60}$ iterations of anything is not going to happen in under 80 seconds. So the outer loop must go. My approach would be to factor m into its constituent prime factors, and use that information together with the CRT to try and simplify the sum.
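
To put a rough number on that estimate, assume (purely for illustration, this rate is not from the thread) a machine that completes $10^9$ loop iterations per second:
$$\frac{10^{18}\ \text{iterations}}{10^{9}\ \text{iterations/s}} = 10^{9}\ \text{s} \approx 31.7\ \text{years}$$
So even under that optimistic assumption, a loop from 1 to $n$ is out of reach by many orders of magnitude.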
 
Bacterius said:
That won't really do much, a 64-bit modular multiplication is quite fast compared to a cache miss (because the precomputed table is going to be quite large and not accessed in a cache-friendly way). Besides, you still end up doing a loop from 1 to n, and since n can be up to $10^{18} \approx 2^{60}$ you're still going to be doing that many iterations, and $2^{60}$ iterations of anything is not going to happen in under 80 seconds. So the outer loop must go. My approach would be to factor m into its constituent prime factors, and use that information together with the CRT to try and simplify the sum.

I think I can factorise $m$ but I have never used CRT before. I looked it up online but could not understand how I would apply it here.

Can you please elaborate the approach? Thanks!
 
Bacterius said:
That won't really do much, a 64-bit modular multiplication is quite fast compared to a cache miss (because the precomputed table is going to be quite large and not accessed in a cache-friendly way). Besides, you still end up doing a loop from 1 to n, and since n can be up to $10^{18} \approx 2^{60}$ you're still going to be doing that many iterations, and $2^{60}$ iterations of anything is not going to happen in under 80 seconds. So the outer loop must go. My approach would be to factor m into its constituent prime factors, and use that information together with the CRT to try and simplify the sum.

Actually, it will have a significant impact.
That is because if $n$ is large, $m$ is likely very small in the test case provided.
So it will hit the cache successfully every time.
It may be enough to get the problem accepted, and the effort involved is trivial. (Nod)

I do agree that $n = 10^{18}$ is far too many iterations to complete in a reasonable amount of time.
If that is really the case, we indeed need a smarter outer loop.

I see a couple of ways that might split it up:
First we have:
$$\begin{aligned}
&1^1 + 2^2 + \dots + m^m + (m+1)^{m+1} + \dots + n^n \\
&\equiv 1^1 + 2^2 + \dots + 0^m + 1^1\cdot 1^m + \dots + (n \bmod m)^{n \bmod m}\cdot \left((n \bmod m)^m\right)^{\lfloor n / m \rfloor} \pmod m
\end{aligned}$$
I'd have to think some more on it, but perhaps we could merge the bases to a separate summation that we can evaluate.
It appears there is a geometric series in there. (Thinking)
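
One way to make the geometric-series idea concrete is sketched below. It groups the terms by residue $r = i \bmod m$, so that $i = r + qm$ contributes $r^r \cdot (r^m)^q$ modulo $m$, and it sums each series by repeated doubling instead of dividing by $r^m - 1$, so no modular inverse is needed. This is a swapped-in technique, not something proposed in the thread, and the names `powMod`, `geometric`, and `sumOfSelfPowers` are hypothetical. It assumes $m^2$ fits in 64 bits and costs roughly $O(m \log n)$ multiplications.

```cpp
#include <cstdint>
#include <utility>

using u64 = std::uint64_t;

// base^exp mod m by binary exponentiation; assumes m * m fits in 64 bits.
u64 powMod(u64 base, u64 exp, u64 m) {
    u64 result = 1 % m;
    base %= m;
    while (exp) {
        if (exp & 1) result = result * base % m;
        base = base * base % m;
        exp >>= 1;
    }
    return result;
}

// Returns {1 + x + ... + x^(k-1) mod m, x^k mod m}, for 0 <= x < m.
// Halving k on even steps keeps the recursion depth near 2 * log2(k).
std::pair<u64, u64> geometric(u64 x, u64 k, u64 m) {
    if (k == 0) return {0, 1 % m};
    if (k & 1) {
        std::pair<u64, u64> prev = geometric(x, k - 1, m);
        return {(1 + x * prev.first) % m, prev.second * x % m};
    }
    std::pair<u64, u64> half = geometric(x, k / 2, m);
    return {half.first * (1 + half.second) % m, half.second * half.second % m};
}

// Sum of i^i for i = 1..n, modulo m, grouped by residue class r = i mod m.
// Multiples of m contribute 0, so r = 0 is skipped.
u64 sumOfSelfPowers(u64 n, u64 m) {
    u64 total = 0;
    for (u64 r = 1; r < m && r <= n; ++r) {
        u64 count = (n - r) / m + 1;   // how many i <= n satisfy i mod m == r
        u64 ratio = powMod(r, m, m);   // common ratio r^m mod m
        u64 series = geometric(ratio, count, m).first;
        total = (total + powMod(r, r, m) * series) % m;
    }
    return total;
}
```

For example, `sumOfSelfPowers(10, 7)` adds up the six residue classes 1 to 6 and agrees with summing $1^1 + \dots + 10^{10}$ directly modulo 7.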

Or we can run a sieve similar to the sieve of Eratosthenes.
If we know the prime factorization of $m$ and also the greatest common divisor $d=\gcd(m,a)$ for a base $a$, we can use:
$$\left(\frac a d\right)^{\phi(m/d)} \equiv 1 \pmod {\frac m d} \Rightarrow a^{\phi(m/d)} \equiv d^{\phi(m/d)} \pmod m$$
That's as far as I got. (Thinking)
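
The congruence above can be checked numerically for small moduli. The sketch below takes $d = \gcd(a, m)$, computes $\phi(m/d)$ by trial division, and compares both sides modulo $m$. All function names are hypothetical, and the code is only a sanity check of the identity, not a step towards the full solution.

```cpp
#include <cstdint>

using u64 = std::uint64_t;

// Greatest common divisor by the Euclidean algorithm.
u64 gcd64(u64 a, u64 b) {
    while (b) { u64 t = a % b; a = b; b = t; }
    return a;
}

// Euler's totient by trial division; fine for small x.
u64 phi(u64 x) {
    u64 result = x;
    for (u64 p = 2; p * p <= x; ++p) {
        if (x % p == 0) {
            while (x % p == 0) x /= p;
            result -= result / p;
        }
    }
    if (x > 1) result -= result / x;
    return result;
}

// base^exp mod m; assumes m * m fits in 64 bits.
u64 powMod(u64 base, u64 exp, u64 m) {
    u64 result = 1 % m;
    base %= m;
    while (exp) {
        if (exp & 1) result = result * base % m;
        base = base * base % m;
        exp >>= 1;
    }
    return result;
}

// True when a^phi(m/d) and d^phi(m/d) agree modulo m, with d = gcd(a, m).
bool identityHolds(u64 a, u64 m) {
    u64 d = gcd64(a, m);
    u64 e = phi(m / d);
    return powMod(a, e, m) == powMod(d, e, m);
}

// Checks the identity for every base a in 1..m.
bool identityHoldsForAllBases(u64 m) {
    for (u64 a = 1; a <= m; ++a) {
        if (!identityHolds(a, m)) return false;
    }
    return true;
}
```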
 
Pranav,
I read your problem just a few days ago. I found it interesting and had fun finding a solution. First, a discussion of my solution. Then a formal implementation in Java, not C++. More on this choice of language later.

[Attached images: 15dxssj.png, 2u73fhi.png]


For reasonably small n, a C++ long long or a Java long suffices. However, if $n>2^{63}-1$, you need arbitrary-precision integers. This is mainly why I chose Java, which comes prepackaged with a BigInteger class.
Code:
package bigsummod;

import java.math.BigInteger;

/**
 *
 * @author John on 7/6/2015
 */
public class BigSumMod {

   public static void main(String[] args) {
      SumModM test = new SumModM();
      long n = 1;
      int i;
      for (i = 1; i <= 18; i++) {
         n *= 10;
      }
      long time1 = System.currentTimeMillis(), time2;
      test.set(104729, n);
      System.out.println("The 10,000th prime is p = 104729");
      System.out.println("For m = p and n=i*10^18 for i=1 to 6:");
      for (i = 1; i <= 6; i++) {
         test.set(i * n);
         System.out.println(test.computeSum());
      }
      System.out.println("time " + (System.currentTimeMillis() - time1)+" milliseconds.");
      BigInteger N = BigInteger.TEN;
      N = N.pow(100); // N is a googol
      System.out.println("For m=p and n=10^100, a googol:");
      time1 = System.currentTimeMillis();
      test.set(N);
      System.out.println(test.computeSum());
      System.out.println("time " + (System.currentTimeMillis() - time1)+" milliseconds.");      
   }
}

package bigsummod;

/**
 * Class to compute sum(i = 1 to n)i^i mod m.  Here m<=200000 and n is "large", say 10^18
 * or even bigger than a long (2^63-1), say the "big" integer 10^100.
 */
import java.math.BigInteger;

public class SumModM {

   int m;
   long n;
   BigInteger N, N_j, R;
   boolean nIsBig;
   int[] primes; // array of prime divisors of m
   int[] primePowers; // m = product of components; each component a prime power
   int[] primePowerExponents; // array of exponents of the primes
   int primeCount; // number of prime divisors of m

   SumModM() {
      primes = new int[10]; // 10 is plenty since m<=200000
      primePowers = new int[10];
      primePowerExponents = new int[10];
   }
   
   public void set(int m,long n) {
      this.m=m;
      this.n=n;
      primeCount=0;
      factorM();
      nIsBig = false;
   }
   
   public void set(long n) {
      this.n = n;
      nIsBig = false;
   }
   
   public void set(BigInteger n) {
      N = n;
      nIsBig = true;
   }
   
/* crude factoring method of m, but fine for "small" m */   
   private void factorM() {
      int n = m, d;
      if ((n & 1) == 0) { // 2 divides n
         primePowers[0] = primes[0] = 2;
         primePowerExponents[0]=1;
         n >>= 1;
         while ((n & 1) == 0) {
            primePowers[0] *= 2;
            primePowerExponents[0]++;
            n >>= 1;
         }
         primeCount = 1;
      }
      d = 3;
      while (d <= n / d) {
         if (n % d == 0) {
            primes[primeCount] = primePowers[primeCount] = d;
            primePowerExponents[primeCount] = 1;
            n /= d;
            while (n % d == 0) {
               primePowers[primeCount] *= d;
               primePowerExponents[primeCount]++;
               n /= d;
            }
            primeCount++;
         }
         d += 2;
      }
      if (n != 1) {
         primes[primeCount] = primePowers[primeCount] = n;
         primePowerExponents[primeCount] = 1;
         primeCount++;
      }
   }
   
/* Computes x^m mod n.  Notice the method requires n*n <= max value for long, namely 2^63-1,
 * so it's perfectly safe for n an int.  This recursive version is essentially just as efficient as a
 * non-recursive version.  The base case m == 0 keeps the recursion from running forever.
 */
   private long powMod(long x, long m, long n) {
      long result;
      if (m == 0) {
         return (1 % n);
      }
      result = powMod(x, m >> 1, n);
      result = (result * result) % n;
      if ((m & 1) == 1) {
         result = (x * result) % n;
      }
      return (result);
   }
   
   int computeSum() {
      int m1 = computeSum(0), modulus1 = primePowers[0];
      int i, m2;
      for (i = 1; i < primeCount; i++) {
         m2 = computeSum(i);
         m1 = chineseRemainder(m1, modulus1, m2, primePowers[i]);
         modulus1 *= primePowers[i];
      }
      return (m1);
   }

   int computeSum(int i) {
      int p = primes[i];
      int m = primePowers[i]; // local copy, so the field m (the full modulus) is not overwritten
      int j, sum = 0;
      long term;
      for (j = 1; j < m; j++) {
         if (j%p==0) {  // try to save a few milliseconds
            if (primePowerExponents[i]<=j) {
               continue;
            }
         }
         term = powMod(j, j, m);
         term = (term * computeSum_j(j, p, m)) % m;
         sum = (int) ((sum + term) % m);
      }
      return (sum);
   }

   int computeSum_j(int j, int p, int m) {
      if (j % p == 0) {
         return (1);
      }
      int x;
      long result, r, inverse;
      if (!nIsBig) {
         long n_j = (n - j) / m;
         if (p == 2) {
            return ((int) ((n_j + 1) % m));
         }
         x = (int) powMod(j, m / p, m);
         if (x == 1) {
            return ((int) ((n_j + 1) % m));
         }
         result = 0;
         r = (n_j + 1) % (p - 1);
         if (r != 0) {
            inverse = aInverseModB(x - 1, m);
            result = ((powMod(x, r, m) - 1) * inverse) % m;
         }
         return ((int) result);
      }
      N_j = N.subtract(BigInteger.valueOf(j));
      N_j = N_j.divide(BigInteger.valueOf(m));
      N_j = N_j.add(BigInteger.ONE);
      if (p == 2) {
         R = N_j.mod(BigInteger.valueOf(m));
         return (R.intValue());
      }
      x = (int) powMod(j, m / p, m);
      if (x == 1) {
         R = N_j.mod(BigInteger.valueOf(m));
         return (R.intValue());
      }
      result = 0;
      R = N_j.mod(BigInteger.valueOf(p - 1));
      if (!R.equals(BigInteger.ZERO)) {
         inverse = aInverseModB(x - 1, m);
         r = R.intValue();
         result = ((powMod(x, r, m) - 1) * inverse) % m;
      }
      return ((int) result);
   }
/* Upon entry m and n are positive coprime ints and a, b are non-negative ints reduced mod m and mod n
 * respectively.  The return is x with 0<=x and x<m*n  where x is congruent to a mod m and x is congruent
 * to b mod n.  Here's how it is done:
 * Clearly a+km is congruent to a for any int k;  so just need a k with a+km congruent to b mod n.  That is,
 * k should be congruent to (b-a)*(inverse of m mod n).  Then easily with 0<=k<n, x is of the correct size.
 */
   int chineseRemainder(int a, int m, int b, int n) {
      long k = (long) (b - a) * aInverseModB(m, n); // long: this product can exceed the int range
      k %= n;
      if (k < 0) {
         k += n;
      }
      return ((int) (a + k * m));
   }

   /* Computation for small n's.  Check for the result of the above algorithm. */
   int feasible() {
      int j;
      int sum = 1;
      long term;
      for (j = 2; j <= n; j++) {
         term = powMod(j, j, m);
         sum = (int) ((sum + term) % m);
      }
      return (sum);
   }
   /* Standard computation of a inverse mod b where a is prime to b. */
   int aInverseModB(int a, int b) {
      int a1 = a;
      int b1 = b;
      int m0 = 1,  m1 = 0,  r, q, m2;
      while (b1 != 0) {
         r = a1 % b1;
         q = a1 / b1;
         m2 = m0 - m1 * q;
         m0 = m1;
         m1 = m2;
         a1 = b1;
         b1 = r;
      }
      if (m0 < 0) {
         m0 = b + m0;
      }
      return (m0);
   }
}
Finally, here are some timings for a couple of runs of the program:
For m=10^5 and n=i*10^18 for i=1 to 6:
62500
62500
62500
62500
62500
62500
time 16 milliseconds.
For m=10^5 and n=10^100, a googol:
62500
time 16 milliseconds.

You might want to verify for yourself that for m=10^5 and n any value with m^2 dividing n, you always get an answer of 62500. So m=100000 is an easy case. Notice the time for a googol is still very fast.

The 10,000th prime is p = 104729
For m = p and n=i*10^18 for i=1 to 6:
52444
39642
378
64098
7399
11003
time 936 milliseconds.
For m=p and n=10^100, a googol:
23255
time 312 milliseconds.
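
Since the original question was about C++, the recombination step performed by `chineseRemainder` in the Java listing might look as follows in C++. This is a minimal sketch, not a port of the whole program: the names `extendedGcd`, `inverseMod`, and `crtCombine` are hypothetical, the moduli are assumed to be coprime, and 64-bit intermediates are used so that the product of two values near 200000 cannot overflow.

```cpp
#include <cstdint>

using i64 = std::int64_t;

// Extended Euclid: returns gcd(a, b) and sets x, y so that a*x + b*y == gcd(a, b).
i64 extendedGcd(i64 a, i64 b, i64& x, i64& y) {
    if (b == 0) {
        x = 1;
        y = 0;
        return a;
    }
    i64 x1 = 0, y1 = 0;
    i64 g = extendedGcd(b, a % b, x1, y1);
    x = y1;
    y = x1 - (a / b) * y1;
    return g;
}

// Inverse of a modulo n in [0, n), assuming gcd(a, n) == 1.
i64 inverseMod(i64 a, i64 n) {
    i64 x = 0, y = 0;
    extendedGcd(((a % n) + n) % n, n, x, y);
    return ((x % n) + n) % n;
}

// Returns x in [0, m*n) with x ≡ a (mod m) and x ≡ b (mod n), for coprime m and n.
// Idea: x = a + k*m, where k ≡ (b - a) * m^{-1} (mod n).
i64 crtCombine(i64 a, i64 m, i64 b, i64 n) {
    i64 diff = ((b - a) % n + n) % n;         // (b - a) reduced into [0, n)
    i64 k = diff * inverseMod(m, n) % n;
    return a + k * m;
}
```

For instance, `crtCombine(2, 3, 3, 5)` yields 8, the unique value below 15 that is 2 modulo 3 and 3 modulo 5. Folding the per-prime-power sums together one at a time, as `computeSum()` does, then gives the answer modulo the full $m$.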
 
Thanks a lot johng! That is exactly what I was looking for, great approach and explanation! :)
 
