C/C++ How can I optimize my C++ code for a sum problem on CodeChef?

  • Thread starter: Saitama
  • Tags: C++, Sum
AI Thread Summary
The discussion revolves around optimizing C++ code for calculating the sum of powers modulo a number in a CodeChef problem. The original approach, which involves a loop from 1 to n, is inefficient due to the potential size of n, which can reach up to 10^18. Suggestions for optimization include using the Chinese Remainder Theorem (CRT) to factor the modulus m into its prime factors and leveraging precomputed values to reduce computation time. Additionally, there are considerations for cache efficiency and the potential benefits of using a different programming language, such as Java, for handling large integers. The conversation emphasizes the need for a smarter algorithm to handle large inputs effectively.
Saitama
I am trying this problem on CodeChef: Just a simple sum

My task is to evaluate:
$$\sum_{i=1}^n i^i \pmod m$$

Following is the code I have written:
Code:
#include <iostream>
using namespace std;

typedef long long ll;

ll modularPower(ll base, ll exponent, ll M)
{
    ll res = 1;
    base %= M;  // reduce first: base can be as large as N, so base * base would otherwise overflow
    while (exponent)
    {
        if (exponent & 1)
            res = (res * base) % M;
        exponent >>= 1;
        base = (base * base) % M;
    }
    return res;
}

int main() {
	ll T, N, M, i, res;
	cin>>T;
	while(T--) {
	    cin>>N>>M;
	    res=0;
	    for(i=1; i<=N; ++i)
	    {
	            res+=modularPower(i, i, M);
	            res%=M;
	    }
	    cout<<res<<endl;
	}
	return 0;
}

But this code exceeds the time limit.

How do I optimise it?

Thanks!
 
Hey Pranav! (Wave)

A standard way to speed up an algorithm is to store intermediate results in a table.

The first thing I can think of is to calculate the square of each number up to $m$ and store it $\bmod m$.
So instead of calculating (base * base) % M, we can simply look it up. (Thinking)
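A minimal sketch of the lookup-table idea in C++, assuming M >= 1 is small (around 2*10^5) so that the table fits in memory and (M-1)*(M-1) fits in 64 bits. The names buildSquareTable and modularPowerTable are hypothetical. Note that the base has to be reduced mod M before the loop so that it is a valid table index.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: sq[r] = r*r % M for every residue 0 <= r < M.
std::vector<std::uint64_t> buildSquareTable(std::uint64_t M) {
    std::vector<std::uint64_t> sq(M);
    for (std::uint64_t r = 0; r < M; ++r)
        sq[r] = r * r % M;
    return sq;
}

// Same square-and-multiply loop as modularPower, but each squaring is a table lookup.
std::uint64_t modularPowerTable(std::uint64_t base, std::uint64_t exponent,
                                std::uint64_t M,
                                const std::vector<std::uint64_t>& sq) {
    std::uint64_t res = 1 % M;
    base %= M;                 // reduce first so base indexes the table
    while (exponent) {
        if (exponent & 1)
            res = res * base % M;
        exponent >>= 1;
        base = sq[base];       // lookup replaces (base * base) % M
    }
    return res;
}
```

Only the squaring step is replaced; the multiply into res still needs a real modular multiplication, and the loop over i from 1 to n is unchanged.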
 
I like Serena said:
Hey Pranav! (Wave)

Standard speed-up of algorithms is by tracking intermediate results in a table.

First thing I can think of, is to calculate the square of each number up to $m$ and store it $\bmod m$.
So instead of calculating (base * base) % M, we can simply look it up. (Thinking)

That won't really do much: a 64-bit modular multiplication is quite fast compared to a cache miss (because the precomputed table is going to be quite large and not accessed in a cache-friendly way). Besides, you still end up doing a loop from 1 to n, and since n can be up to $10^{18} \approx 2^{60}$, you're still going to be doing that many iterations, and $2^{60}$ iterations of anything is not going to happen in under 80 seconds. So the outer loop must go. My approach would be to factor m into its prime factors and use that information together with the CRT to try to simplify the sum.
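A sketch of the two building blocks this approach needs, in C++ (function names are hypothetical): splitting m into prime-power factors by trial division, and combining a residue mod m1 with a residue mod m2 via the CRT. It assumes m is small (around 2*10^5), so trial division is cheap and products of residues fit in 64 bits.

```cpp
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;
using i64 = std::int64_t;

// Split m into its prime-power factors p^e by trial division.
std::vector<u64> primePowerFactors(u64 m) {
    std::vector<u64> parts;
    for (u64 d = 2; d * d <= m; ++d) {
        if (m % d != 0) continue;
        u64 pe = 1;
        while (m % d == 0) { m /= d; pe *= d; }
        parts.push_back(pe);
    }
    if (m > 1) parts.push_back(m);   // leftover prime factor
    return parts;
}

// Inverse of a modulo n via the extended Euclidean algorithm; assumes gcd(a, n) = 1.
i64 inverseMod(i64 a, i64 n) {
    i64 old_r = a % n, r = n, old_s = 1, s = 0;
    while (r != 0) {
        i64 q = old_r / r;
        i64 tmp = old_r - q * r; old_r = r; r = tmp;
        tmp = old_s - q * s; old_s = s; s = tmp;
    }
    return ((old_s % n) + n) % n;
}

// Combine x = a (mod m1) and x = b (mod m2), gcd(m1, m2) = 1, 0 <= a < m1,
// into the unique x with 0 <= x < m1*m2.
i64 crtCombine(i64 a, i64 m1, i64 b, i64 m2) {
    i64 k = ((b - a) % m2 + m2) % m2 * inverseMod(m1 % m2, m2) % m2;
    return a + k * m1;   // congruent to a mod m1 and to b mod m2
}
```

Usage: compute the sum modulo each prime power from primePowerFactors(m), then fold the results together with crtCombine, multiplying the running modulus by each factor in turn.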
 
Bacterius said:
That won't really do much, a 64-bit modular multiplication is quite fast compared to a cache miss (because the precomputed table is going to be quite large and not accessed in a cache-friendly way). Besides, you still end up doing a loop from 1 to n, and since n can be up to $10^{18} \approx 2^{60}$ you're still going to be doing that many iterations, and $2^{60}$ iterations of anything is not going to happen in under 80 seconds. So the outer loop must go. My approach would be to factor m into its constituent prime factors, and use that information together with the CRT to try and simplify the sum.

I think I can factorise $m$ but I have never used CRT before. I looked it up online but could not understand how I would apply it here.

Can you please elaborate on the approach? Thanks!
 
Bacterius said:
That won't really do much, a 64-bit modular multiplication is quite fast compared to a cache miss (because the precomputed table is going to be quite large and not accessed in a cache-friendly way). Besides, you still end up doing a loop from 1 to n, and since n can be up to $10^{18} \approx 2^{60}$ you're still going to be doing that many iterations, and $2^{60}$ iterations of anything is not going to happen in under 80 seconds. So the outer loop must go. My approach would be to factor m into its constituent prime factors, and use that information together with the CRT to try and simplify the sum.

Actually, it will have a significant impact.
That is because if $n$ is large, $m$ is likely very small in the test cases provided.
So the table will stay in the cache and every lookup will hit.
It may be enough to get the problem accepted, and the effort involved is trivial. (Nod)

I do agree that $n = 10^{18}$ is way too large to complete in a reasonable amount of time.
If that is really the case, we indeed need a smarter outer loop.

I see a couple of ways that might split it up:
First we have:
$$\begin{aligned}
&1^1 + 2^2 + \dots + m^m + (m+1)^{m+1} + \dots + n^n \\
&\equiv 1^1 + 2^2 + \dots + 0^m + 1^1\cdot 1^m + \dots + (n \bmod m)^{n \bmod m}\cdot\left((n \bmod m)^m\right)^{\lfloor n/m \rfloor} \pmod m
\end{aligned}$$
I'd have to think some more about it, but perhaps we could collect the terms with the same base into a separate summation that we can evaluate.
It appears there is a geometric series in there. (Thinking)
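The geometric-series idea can be carried through to a complete algorithm. The C++ sketch below is one possible arrangement (not taken from the thread; names are hypothetical): group i by its residue r mod M, so that $i = r + kM$ gives $i^i \equiv r^r (r^M)^k \pmod M$, and each residue class contributes $r^r$ times a geometric series in $r^M$. The series is summed by halving rather than by dividing by $r^M - 1$, because that difference need not be invertible mod M. It assumes M is around 2*10^5 or less so that M*M fits in 64 bits; the cost is O(M log n) per test case, and whether that is fast enough depends on T and the judge's limits.

```cpp
#include <cstdint>
#include <utility>

using u64 = std::uint64_t;

// x^e mod M by square-and-multiply (assumes M >= 1 and M*M fits in 64 bits).
u64 powMod(u64 x, u64 e, u64 M) {
    u64 r = 1 % M;
    x %= M;
    while (e) {
        if (e & 1) r = r * x % M;
        x = x * x % M;
        e >>= 1;
    }
    return r;
}

// Returns {1 + x + ... + x^(c-1), x^c}, both mod M, in O(log c) steps; x < M.
// No modular inverse is needed, so it works even when x - 1 is not invertible.
std::pair<u64, u64> geomSum(u64 x, u64 c, u64 M) {
    if (c == 0) return {0, 1 % M};
    if (c & 1) {                          // G(c) = 1 + x * G(c - 1)
        auto [s, p] = geomSum(x, c - 1, M);
        return {(1 + x * s) % M, x * p % M};
    }
    auto [s, p] = geomSum(x, c / 2, M);   // G(2k) = G(k) * (1 + x^k)
    return {s * ((1 + p) % M) % M, p * p % M};
}

// Sum_{i=1..n} i^i mod M, grouping i by residue r in 1..M.
u64 sumPowersMod(u64 n, u64 M) {
    u64 total = 0;
    for (u64 r = 1; r <= M && r <= n; ++r) {
        u64 count = (n - r) / M + 1;           // number of i <= n with i = r (mod M)
        u64 ratio = powMod(r, M, M);           // r^M mod M
        u64 series = geomSum(ratio, count, M).first;
        total = (total + powMod(r, r, M) * series) % M;
    }
    return total;
}
```

For example, sumPowersMod(3, 100) is 1 + 4 + 27 = 32, and the original brute-force loop can be used to cross-check it for small n.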

Or we can run a sieve similar to the sieve of Eratosthenes.
If we know the prime factorization of $m$ and also the greatest common divisor $d=\gcd(m,a)$ for a base $a$, we can use:
$$\left(\frac a d\right)^{\phi(m/d)} \equiv 1 \pmod {\frac m d} \Rightarrow a^{\phi(m/d)} \equiv d^{\phi(m/d)} \pmod m$$
That's as far as I got. (Thinking)
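The congruence above holds because $\gcd(a/d, m/d) = 1$, so writing $(a/d)^{t} = 1 + s\,(m/d)$ with $t = \phi(m/d) \ge 1$ gives $a^t = d^t + s\,d^{t-1} m$. A quick numerical check in C++ (a hypothetical test harness, not part of anyone's solution) for all small $m$ and $a$:

```cpp
#include <cstdint>
#include <numeric>

using u64 = std::uint64_t;

// Euler's totient by trial division (small arguments only).
u64 phi(u64 n) {
    u64 result = n;
    for (u64 p = 2; p * p <= n; ++p) {
        if (n % p != 0) continue;
        while (n % p == 0) n /= p;
        result -= result / p;
    }
    if (n > 1) result -= result / n;
    return result;
}

// x^e mod M by square-and-multiply.
u64 powMod(u64 x, u64 e, u64 M) {
    u64 r = 1 % M;
    x %= M;
    while (e) {
        if (e & 1) r = r * x % M;
        x = x * x % M;
        e >>= 1;
    }
    return r;
}

// True if a^phi(m/d) = d^phi(m/d) (mod m), d = gcd(m, a), for 2 <= m, 1 <= a up to limit.
bool checkIdentity(u64 limit) {
    for (u64 m = 2; m <= limit; ++m)
        for (u64 a = 1; a <= limit; ++a) {
            u64 d = std::gcd(m, a);
            u64 t = phi(m / d);
            if (powMod(a, t, m) != powMod(d, t, m)) return false;
        }
    return true;
}
```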
 
Pranav,
I read your problem just a few days ago. I found it interesting and had fun finding a solution. First, a discussion of my solution. Then a formal implementation in Java, not C++. More on this choice of language later.

[Two images (15dxssj.png, 2u73fhi.png) containing the discussion of the solution]


For n that fits in 64 bits, you can use a C++ long long or a Java long. However, if $n>2^{63}-1$, you need the ability to calculate with such large integers. This is mainly why I chose Java, since Java comes prepackaged with a BigInteger class.
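To stay in C++ without a big-integer library, one alternative (a swapped-in technique, not what the Java code below does) is to read n as a decimal string and reduce it directly. A quantity such as $\lfloor (n-j)/m \rfloor$ reduced mod $q$ depends only on $n \bmod mq$: writing $n - j = (mq)\,t + s$ with $0 \le s < mq$ gives $\lfloor (n-j)/m \rfloor = qt + \lfloor s/m \rfloor$. The helper names are hypothetical and assume $n \ge j$ and $mq$ up to about $4 \cdot 10^{10}$.

```cpp
#include <cstdint>
#include <string>

using u64 = std::uint64_t;

// N mod k for a non-negative N given as a decimal string (Horner's rule).
u64 decimalMod(const std::string& N, u64 k) {
    u64 r = 0;
    for (char c : N) r = (r * 10 + static_cast<u64>(c - '0')) % k;
    return r;
}

// floor((N - j) / m) mod q, assuming N >= j: only N mod (m*q) is needed.
u64 floorDivMod(const std::string& N, u64 j, u64 m, u64 q) {
    u64 mq = m * q;
    u64 s = (decimalMod(N, mq) + mq - j % mq) % mq;   // (N - j) mod m*q
    return s / m;   // floor(s/m) < q, so it is already reduced mod q
}
```

With q = m or q = p - 1, adding 1 and reducing mod q gives the same residues of $N_j = \lfloor (N-j)/m \rfloor + 1$ that the Java code obtains from BigInteger.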
Code:
package bigsummod;

import java.math.BigInteger;

/**
 *
 * @author John on 7/6/2015
 */
public class BigSumMod {

   public static void main(String[] args) {
      SumModM test = new SumModM();
      long n = 1;
      int i;
      for (i = 1; i <= 18; i++) {
         n *= 10;
      }
      long time1 = System.currentTimeMillis(), time2;
      test.set(104729, n);
      System.out.println("The 10,000th prime is p = 104729");
      System.out.println("For m = p and n=i*10^18 for i=1 to 6:");
      for (i = 1; i <= 6; i++) {
         test.set(i * n);
         System.out.println(test.computeSum());
      }
      System.out.println("time " + (System.currentTimeMillis() - time1)+" milliseconds.");
      BigInteger N = BigInteger.TEN;
      N = N.pow(100); // N is a googol
      System.out.println("For m=p and n=10^100, a googol:");
      time1 = System.currentTimeMillis();
      test.set(N);
      System.out.println(test.computeSum());
      System.out.println("time " + (System.currentTimeMillis() - time1)+" milliseconds.");      
   }
}

package bigsummod;

/**
 * Class to compute sum(i = 1 to n)i^i mod m.  Here m<=200000 and n is "large", say 10^18
 * or even bigger than a long (2^63-1), say the "big" integer 10^100.
 */
import java.math.BigInteger;

public class SumModM {

   int m;
   long n;
   BigInteger N, N_j, R;
   boolean nIsBig;
   int[] primes; // array of prime divisors of m
   int[] primePowers; // m = product of components; each component a prime power
   int[] primePowerExponents; // array of exponents of the primes
   int primeCount; // number of prime divisors of m

   SumModM() {
      primes = new int[10]; // 10 is plenty since m<=200000
      primePowers = new int[10];
      primePowerExponents = new int[10];
   }
   
   public void set(int m,long n) {
      this.m=m;
      this.n=n;
      primeCount=0;
      factorM();
      nIsBig = false;
   }
   
   public void set(long n) {
      this.n = n;
      nIsBig = false;
   }
   
   public void set(BigInteger n) {
      N = n;
      nIsBig = true;
   }
   
/* crude factoring method of m, but fine for "small" m */   
   private void factorM() {
      int n = m, d;
      if ((n & 1) == 0) { // 2 divides n
         primePowers[0] = primes[0] = 2;
         primePowerExponents[0]=1;
         n >>= 1;
         while ((n & 1) == 0) {
            primePowers[0] *= 2;
            primePowerExponents[0]++;
            n >>= 1;
         }
         primeCount = 1;
      }
      d = 3;
      while (d <= n / d) {
         if (n % d == 0) {
            primes[primeCount] = primePowers[primeCount] = d;
            primePowerExponents[primeCount] = 1;
            n /= d;
            while (n % d == 0) {
               primePowers[primeCount] *= d;
               primePowerExponents[primeCount]++;
               n /= d;
            }
            primeCount++;
         }
         d += 2;
      }
      if (n != 1) {
         primes[primeCount] = primePowers[primeCount] = n;
         primePowerExponents[primeCount] = 1;
         primeCount++;
      }
   }
   
/* Computes x^m mod n for m >= 1, where 0 <= x < n.  The intermediate products require n*n <= 2^63-1
 * (the maximum long), so it's perfectly safe for n an int.  This recursive version is essentially
 * just as efficient as a non-recursive version.
 */
   private long powMod(long x, long m, long n) {
      long result;
      if (m == 1) {
         return (x % n); // % n also handles n == 1
      }
      result = powMod(x, m >> 1, n);
      result = (result * result) % n;
      if ((m & 1) == 1) {
         result = (x * result) % n;
      }
      return (result);
   }
   
   int computeSum() {
      int m1 = computeSum(0), modulus1 = primePowers[0];
      int i, m2;
      for (i = 1; i < primeCount; i++) {
         m2 = computeSum(i);
         m1 = chineseRemainder(m1, modulus1, m2, primePowers[i]);
         modulus1 *= primePowers[i];
      }
      return (m1);
   }

   int computeSum(int i) {
      int p = primes[i];
      int m = primePowers[i]; // local: don't overwrite the field m, which feasible() still needs
      int j, sum = 0;
      long term;
      for (j = 1; j < m; j++) {
         if (j%p==0) {  // try to save a few milliseconds
            if (primePowerExponents[i]<=j) {
               continue;
            }
         }
         term = powMod(j, j, m);
         term = (term * computeSum_j(j, p, m)) % m;
         sum = (int) ((sum + term) % m);
      }
      return (sum);
   }

   int computeSum_j(int j, int p, int m) {
      if (j % p == 0) {
         return (1);
      }
      int x;
      long result, r, inverse;
      if (!nIsBig) {
         long n_j = (n - j) / m;
         if (p == 2) {
            return ((int) ((n_j + 1) % m));
         }
         x = (int) powMod(j, m / p, m);
         if (x == 1) {
            return ((int) ((n_j + 1) % m));
         }
         result = 0;
         r = (n_j + 1) % (p - 1);
         if (r != 0) {
            inverse = aInverseModB(x - 1, m);
            result = ((powMod(x, r, m) - 1) * inverse) % m;
         }
         return ((int) result);
      }
      N_j = N.subtract(BigInteger.valueOf(j));
      N_j = N_j.divide(BigInteger.valueOf(m));
      N_j = N_j.add(BigInteger.ONE);
      if (p == 2) {
         R = N_j.mod(BigInteger.valueOf(m));
         return (R.intValue());
      }
      x = (int) powMod(j, m / p, m);
      if (x == 1) {
         R = N_j.mod(BigInteger.valueOf(m));
         return (R.intValue());
      }
      result = 0;
      R = N_j.mod(BigInteger.valueOf(p - 1));
      if (!R.equals(BigInteger.ZERO)) {
         inverse = aInverseModB(x - 1, m);
         r = R.intValue();
         result = ((powMod(x, r, m) - 1) * inverse) % m;
      }
      return ((int) result);
   }
/* Upon entry m and n are positive coprime ints and a, b are non-negative ints reduced mod m and mod n
 * respectively.  The return is x with 0<=x and x<m*n  where x is congruent to a mod m and x is congruent
 * to b mod n.  Here's how it is done:
 * Clearly a+km is congruent to a for any int k;  so just need a k with a+km congruent to b mod n.  That is,
 * k should be congruent to (b-a)*(inverse of m mod n).  Then easily with 0<=k<n, x is of the correct size.
 */
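   int chineseRemainder(int a, int m, int b, int n) {
      // Use long: (b - a) * inverse can approach n*n, which overflows an int for n near 200000.
      long k = ((long) (b - a) * aInverseModB(m, n)) % n;
      if (k < 0) {
         k += n;
      }
      return (int) (a + k * m);
   }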

   /* Brute-force computation for small n, useful as a check on the result of the algorithm above. */
   int feasible() {
      int j;
      int sum = 1;
      long term;
      for (j = 2; j <= n; j++) {
         term = powMod(j, j, m);
         sum = (int) ((sum + term) % m);
      }
      return (sum);
   }
   /* Standard computation (extended Euclidean algorithm) of the inverse of a mod b, where a is prime to b. */
   int aInverseModB(int a, int b) {
      int a1 = a;
      int b1 = b;
      int m0 = 1,  m1 = 0,  r, q, m2;
      while (b1 != 0) {
         r = a1 % b1;
         q = a1 / b1;
         m2 = m0 - m1 * q;
         m0 = m1;
         m1 = m2;
         a1 = b1;
         b1 = r;
      }
      if (m0 < 0) {
         m0 = b + m0;
      }
      return (m0);
   }
}
Finally, here are some timings for a couple of runs of the program:
For m=10^5 and n=i*10^18 for i=1 to 6:
62500
62500
62500
62500
62500
62500
time 16 milliseconds.
For m=10^5 and n=10^100, a googol:
62500
time 16 milliseconds.

You might want to verify for yourself that for m=10^5 and n any value with m^2 dividing n, you always get an answer of 62500. So m=100000 is an easy case. Notice the time for a googol is still very fast.

The 10,000th prime is p = 104729
For m = p and n=i*10^18 for i=1 to 6:
52444
39642
378
64098
7399
11003
time 936 milliseconds.
For m=p and n=10^100, a googol:
23255
time 312 milliseconds.
 
Thanks a lot johng! That is exactly what I was looking for, great approach and explanation! :)
 
