package bigsummod;
import java.math.BigInteger;
/**
 *
 * @author John on 7/6/2015
 */
public class BigSumMod {
   /** The 10,000th prime; used as the modulus m for every demo computation. */
   private static final int PRIME_10000 = 104729;

   /** 10^18, the base step for the long-valued upper limits n = i * 10^18. */
   private static final long TEN_TO_18 = 1000000000000000000L;

   /**
    * Demo driver.
    *
    * <p>It prints sum_{i=1}^{n} i^i mod 104729 in two phases, each followed by its wall-clock time:
    * <ul>
    *   <li>for n = k*10^18 with k = 1..6, using the long code path of {@link SumModM};</li>
    *   <li>for n = 10^100 (a googol), using the BigInteger code path.</li>
    * </ul>
    *
    * @param args ignored
    */
   public static void main(String[] args) {
      SumModM test = new SumModM();
      long start = System.currentTimeMillis();
      test.set(PRIME_10000, TEN_TO_18);
      System.out.println("The 10,000th prime is p = 104729");
      System.out.println("For m = p and n=i*10^18 for i=1 to 6:");
      for (int k = 1; k <= 6; k++) {
         // 6 * 10^18 is below Long.MAX_VALUE (~9.22 * 10^18), so the product cannot overflow.
         test.set(k * TEN_TO_18);
         System.out.println(test.computeSum());
      }
      System.out.println("time " + (System.currentTimeMillis() - start)+" milliseconds.");
      BigInteger googol = BigInteger.TEN.pow(100);
      // Informational header: printed on stdout like the rest of the report so lines stay in order.
      System.out.println("For m=p and n=10^100, a googol:");
      start = System.currentTimeMillis();
      test.set(googol);
      System.out.println(test.computeSum());
      System.out.println("time " + (System.currentTimeMillis() - start)+" milliseconds.");
   }
}
package bigsummod;
/**
 * Class to compute sum(i = 1 to n)i^i mod m.  Here m<=200000 and n is "large", say 10^18
 * or even bigger than a long (2^63-1), say the "big" integer 10^100.
 */
import java.math.BigInteger;
/**
 * Computes S = sum_{i=1}^{n} i^i mod m.
 *
 * <p>The modulus m is a positive int (the demo uses m <= 200000). The upper limit n is either a
 * long ({@link #set(long)}) or an arbitrary {@link BigInteger} ({@link #set(BigInteger)}).
 *
 * <p>Method: m is factored into prime powers q = p^e, and for each q every residue j = 1..q-1
 * is visited once.
 * <ul>
 *   <li>If p | j and j >= e, every i == j (mod q) has i^i == 0 (mod q); the class is skipped.</li>
 *   <li>If p | j and j < e, only i = j itself can be non-zero mod q, contributing j^j.</li>
 *   <li>Otherwise i = j + k*q gives i^i == j^i == j^j * x^k (mod q), where
 *       x = j^q == j^(q/p) (mod q) and x^(p-1) == 1. The sum over k is therefore a geometric
 *       series.</li>
 * </ul>
 * The residue-0 class contributes 0. Per-prime-power results are combined with the Chinese
 * Remainder Theorem.
 */
public class SumModM {
   /** The modulus passed to {@link #set(int, long)}; always positive. */
   int m;
   /** Upper summation limit, used when {@link #nIsBig} is false. */
   long n;
   /** Upper summation limit, used when {@link #nIsBig} is true. */
   BigInteger N;
   /** Former scratch values of computeSum_j; no longer written, kept for source compatibility. */
   BigInteger N_j, R;
   /** True when the limit was supplied as a BigInteger via {@link #set(BigInteger)}. */
   boolean nIsBig;
   int[] primes; // array of prime divisors of m
   int[] primePowers; // m = product of components; each component a prime power
   int[] primePowerExponents; // array of exponents of the primes
   int primeCount; // number of prime divisors of m

   SumModM() {
      // 10 is plenty: 2*3*5*7*11*13*17*19*23*29 already exceeds Integer.MAX_VALUE,
      // so a positive int has at most 9 distinct prime factors.
      primes = new int[10];
      primePowers = new int[10];
      primePowerExponents = new int[10];
   }

   /**
    * Sets both the modulus and a long upper limit, and factors the modulus.
    *
    * @param m modulus, must be >= 1
    * @param n upper summation limit (n < 1 gives an empty sum)
    * @throws IllegalArgumentException if m < 1
    */
   public void set(int m, long n) {
      if (m < 1) {
         throw new IllegalArgumentException("modulus must be positive, got " + m);
      }
      this.m = m;
      this.n = n;
      factorM();
      nIsBig = false;
   }

   /** Sets a new long upper limit, keeping the current modulus. */
   public void set(long n) {
      this.n = n;
      nIsBig = false;
   }

   /** Sets a new BigInteger upper limit, keeping the current modulus. */
   public void set(BigInteger n) {
      N = n;
      nIsBig = true;
   }

   /*
    * Trial-division factoring of m into primes[], primePowers[] and primePowerExponents[].
    * Crude, but fine for "small" m. Resets primeCount, so stale entries from an earlier
    * modulus are never counted.
    */
   private void factorM() {
      int rest = m;
      primeCount = 0;
      if ((rest & 1) == 0) { // 2 divides m
         primePowers[0] = primes[0] = 2;
         primePowerExponents[0] = 1;
         rest >>= 1;
         while ((rest & 1) == 0) {
            primePowers[0] *= 2;
            primePowerExponents[0]++;
            rest >>= 1;
         }
         primeCount = 1;
      }
      // d <= rest / d is an overflow-free form of d * d <= rest.
      for (int d = 3; d <= rest / d; d += 2) {
         if (rest % d == 0) {
            primes[primeCount] = primePowers[primeCount] = d;
            primePowerExponents[primeCount] = 1;
            rest /= d;
            while (rest % d == 0) {
               primePowers[primeCount] *= d;
               primePowerExponents[primeCount]++;
               rest /= d;
            }
            primeCount++;
         }
      }
      if (rest != 1) { // the remaining cofactor is a prime larger than sqrt of what was left
         primes[primeCount] = primePowers[primeCount] = rest;
         primePowerExponents[primeCount] = 1;
         primeCount++;
      }
   }

   /*
    * Returns x^m mod n in [0, n) by iterative square-and-multiply.
    * Requires m >= 0 and n >= 1; x^0 gives 1 % n.
    * Products of two residues must fit in a long, so n should be at most ~3.03e9; any int is safe.
    */
   private long powMod(long x, long m, long n) {
      long base = x % n;
      if (base < 0) {
         base += n;
      }
      long result = 1 % n;
      for (long e = m; e > 0; e >>= 1) {
         if ((e & 1) == 1) {
            result = (result * base) % n;
         }
         base = (base * base) % n;
      }
      return result;
   }

   /**
    * Returns sum_{i=1}^{n} i^i mod m for the values given to the set methods.
    *
    * @return the sum reduced into [0, m); 0 when m == 1
    */
   int computeSum() {
      if (primeCount == 0) { // m == 1 (or never set): every value is 0 mod 1
         return 0;
      }
      int result = computeSum(0);
      int modulus = primePowers[0];
      for (int i = 1; i < primeCount; i++) {
         result = chineseRemainder(result, modulus, computeSum(i), primePowers[i]);
         modulus *= primePowers[i]; // never exceeds m, so stays within int
      }
      return result;
   }

   /**
    * Returns the sum modulo the i-th prime-power component q = primePowers[i] of m.
    * Does not modify the field m.
    *
    * @param i index into the factor arrays, 0 <= i < primeCount
    * @return sum_{k=1}^{n} k^k mod q
    */
   int computeSum(int i) {
      int p = primes[i];
      int q = primePowers[i];
      int e = primePowerExponents[i];
      long sum = 0;
      for (int j = 1; j < q; j++) {
         // If p | j and j >= e, then k^k == 0 (mod p^e) for every k == j (mod q).
         if (j % p == 0 && e <= j) {
            continue;
         }
         long term = powMod(j, j, q);
         term = (term * computeSum_j(j, p, q)) % q;
         sum = (sum + term) % q;
      }
      return (int) sum;
   }

   /**
    * Returns the factor by which j^j must be multiplied to get the total contribution, mod m, of
    * all i in [1, n] with i == j (mod m).
    *
    * <p>For j coprime to p this is sum_{k=0}^{count-1} x^k with x = j^(m/p) mod m, where count is
    * the number of such i. Returns 0 when no i in [1, n] falls in the class (n < j).
    *
    * @param j residue, assumed 1 <= j < m
    * @param p the prime with m = p^e
    * @param m the prime power p^e
    */
   int computeSum_j(int j, int p, int m) {
      long countModM;     // number of terms, reduced mod m
      long countModOrder; // number of terms, reduced mod (p - 1); x^(p-1) == 1 (mod m)
      if (nIsBig) {
         BigInteger bigJ = BigInteger.valueOf(j);
         if (N.compareTo(bigJ) < 0) {
            return 0; // the class is empty; BigInteger.divide would truncate toward zero
         }
         BigInteger count = N.subtract(bigJ).divide(BigInteger.valueOf(m)).add(BigInteger.ONE);
         countModM = count.mod(BigInteger.valueOf(m)).longValue();
         countModOrder = count.mod(BigInteger.valueOf(p - 1)).longValue();
      } else {
         if (n < j) {
            return 0; // the class is empty; (n - j) / m would truncate to 0 and count one term
         }
         long count = (n - j) / m + 1;
         countModM = count % m;
         countModOrder = count % (p - 1);
      }
      if (j % p == 0) {
         return 1; // only i = j itself is non-zero mod m (caller skips j >= e)
      }
      if (p == 2) {
         return (int) countModM; // x == 1 for every odd j modulo a power of 2
      }
      long x = powMod(j, m / p, m);
      if (x == 1) {
         return (int) countModM;
      }
      if (countModOrder == 0) {
         return 0; // x^count == 1, so (x^count - 1) / (x - 1) == 0
      }
      // x != 1 (mod p^e) with x^(p-1) == 1 implies x != 1 (mod p), so x - 1 is invertible.
      long inverse = aInverseModB((int) (x - 1), m);
      return (int) (((powMod(x, countModOrder, m) - 1) * inverse) % m);
   }

   /*
    * Combines two congruences with the Chinese Remainder Theorem.
    * Upon entry m and n are positive coprime ints, and a, b are non-negative ints reduced mod m
    * and mod n respectively. Returns x with 0 <= x < m*n, x == a (mod m) and x == b (mod n).
    * How: a + k*m == a (mod m) for any k, so choose k == (b - a) * (inverse of m mod n), 0 <= k < n.
    * The product is formed in long because it can exceed Integer.MAX_VALUE (e.g. m=2, n=99991).
    */
   int chineseRemainder(int a, int m, int b, int n) {
      long k = ((long) (b - a) * aInverseModB(m, n)) % n;
      if (k < 0) {
         k += n;
      }
      return (int) (a + k * m);
   }

   /*
    * Brute-force reference for small long n: directly sums j^j mod m for j = 1..n.
    * It ignores the BigInteger limit N and is used to check the result of the algorithm above.
    */
   int feasible() {
      long sum = 0;
      for (long j = 1; j <= n; j++) {
         sum = (sum + powMod(j, j, m)) % m;
      }
      return (int) sum;
   }

   /*
    * Standard extended-Euclid computation of the inverse of a mod b, for a prime to b.
    * Returns a value in [0, b); the result is meaningless when gcd(a, b) != 1.
    * Intermediates are long so the coefficient updates cannot overflow.
    */
   int aInverseModB(int a, int b) {
      long oldR = a, r = b;
      long oldS = 1, s = 0;
      while (r != 0) {
         long q = oldR / r;
         long nextR = oldR - q * r;
         oldR = r;
         r = nextR;
         long nextS = oldS - q * s;
         oldS = s;
         s = nextS;
      }
      if (oldS < 0) {
         oldS += b;
      }
      return (int) oldS;
   }
}