Library

This documentation is automatically generated by competitive-verifier/competitive-verifier

View the Project on GitHub anmichi/Library

✔ mod_sqrt.cpp

Depends on: template.cpp

Required by

Verified with

Code

#include "template.cpp"
// Modular square root by the Tonelli-Shanks algorithm.
//   a : value whose square root is wanted; 0 <= a < p (asserted).
//   p : modulus; must be prime (2 or an odd prime) for the result to be valid.
// Returns some r in [0, p) with r * r % p == a, or -1 when a is a quadratic
// non-residue modulo p. For a >= 2 only one of the two roots (r, p - r) is returned.
// All products are formed in __int128, so any p < 2^63 is supported (a plain
// int64_t product overflows once p exceeds roughly 3.04e9).
int64_t mod_sqrt(const int64_t& a, const int64_t& p) {
    assert(0 <= a && a < p);
    if (a < 2) return a;  // 0 and 1 are their own square roots.
    // (x * y) mod p without 64-bit overflow, for x, y in [0, p).
    const auto mul = [p](int64_t x, int64_t y) -> int64_t {
        return static_cast<int64_t>(static_cast<__int128>(x) * y % p);
    };
    // x^e mod p by binary exponentiation (e >= 0, x in [0, p)).
    const auto power = [&mul](int64_t x, int64_t e) -> int64_t {
        int64_t res = 1;
        while (e > 0) {
            if (e & 1) res = mul(res, x);
            x = mul(x, x);
            e >>= 1;
        }
        return res;
    };
    const int64_t half = (p - 1) / 2;
    // Euler's criterion: a is a square mod p iff a^((p-1)/2) == 1.
    if (power(a, half) != 1) return -1;
    // Factor p - 1 = q * 2^s with q odd.
    int64_t q = p - 1;
    int s = 0;
    while ((q & 1) == 0) {
        q >>= 1;
        ++s;
    }
    // Find a quadratic non-residue z (Euler's criterion does not give 1 for it).
    int64_t z = 2;
    while (power(z, half) == 1) ++z;
    int64_t c = power(z, q);
    int64_t t = power(a, q);
    int64_t r = power(a, (q + 1) / 2);
    // Invariants: r * r == a * t, t^(2^(s-1)) == 1, c has order exactly 2^s.
    while (t != 1) {
        // Least i (0 < i < s) with t^(2^i) == 1.
        int i = 0;
        for (int64_t u = t; u != 1; u = mul(u, u)) ++i;
        // b = c^(2^(s - i - 1))
        int64_t b = c;
        for (int k = 0; k < s - i - 1; ++k) b = mul(b, b);
        r = mul(r, b);
        c = mul(b, b);
        t = mul(t, c);
        s = i;
    }
    return r;
}
#line 1 "template.cpp"
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
// Shorthand for long long, used throughout the library.
using ll = long long;
// Ordered set on the GNU policy-based red-black tree. Per the pb_ds docs, the
// tree_order_statistics_node_update policy adds find_by_order(k) (k-th element,
// 0-indexed) and order_of_key(x) (number of keys comparing less than x).
template <class T>
using pbds_set = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
// Same tree with less_equal<T> as the comparator so that duplicate keys are kept.
// NOTE(review): less_equal is not a strict weak ordering, so key-based lookups do
// not treat equal keys as equivalent: find(x) / erase(x) cannot locate an existing
// x, and lower_bound/upper_bound behave swapped relative to std::multiset.
// Remove elements through an iterator and verify behaviour at each call site.
template <class T>
using pbds_mset = tree<T, null_type, less_equal<T>, rb_tree_tag, tree_order_statistics_node_update>;
// Patricia trie over string keys; per the pb_ds docs, trie_prefix_search_node_update
// adds prefix_range(prefix) to iterate over the keys that start with a prefix.
using pbds_trie = trie<string, null_type, trie_string_access_traits<>, pat_trie_tag, trie_prefix_search_node_update>;
// rep(i, n): loop an int index i over 0, 1, ..., n - 1. The bound expression is
// re-evaluated on every iteration. It is parenthesised so that arguments such as
// `c ? a : b` or `x & mask` keep their meaning after expansion.
#define rep(i, n) for (int i = 0; i < (n); i++)
// all(v): the [begin, end) iterator pair of v, for passing to algorithms.
// Parenthesised so that all(*ptr) expands to (*ptr).begin(), (*ptr).end().
#define all(v) (v).begin(), (v).end()
// Raises a to b when a < b. Returns true iff a was overwritten.
// Typical use: if (chmax(best, candidate)) { /* new maximum */ }
template <class T, class U>
inline bool chmax(T &a, U b) {
    if (!(a < b)) return false;  // already at least b: leave a alone
    a = b;
    return true;
}
// Lowers a to b when a > b. Returns true iff a was overwritten.
// Typical use: if (chmin(best, candidate)) { /* new minimum */ }
template <class T, class U>
inline bool chmin(T &a, U b) {
    if (!(a > b)) return false;  // already at most b: leave a alone
    a = b;
    return true;
}
// Large sentinel values: INF = 10^9 fits in int, llINF = 3 * 10^18 fits in int64_t.
// Both doubles below are exactly representable, so the casts are exact.
constexpr int INF = static_cast<int>(1e9);
constexpr int64_t llINF = static_cast<int64_t>(3e18);
// Small epsilon (presumably a floating-point comparison tolerance) and pi = arccos(-1).
constexpr double eps = 1e-10;
const double pi = acos(-1.0);
// Sorts a and drops duplicate elements in place, leaving each distinct value once
// in ascending order (operator< for ordering, operator== for duplicates).
template <class T>
inline void compress(vector<T> &a) {
    sort(a.begin(), a.end());
    const auto dedup_end = unique(a.begin(), a.end());
    a.erase(dedup_end, a.end());
}
// Linear (Euler) sieve on [0, n].
//   least_factor[x] : smallest prime factor of x for 2 <= x <= n; 0 for x = 0 and 1.
//   prime_list      : every prime <= n, in increasing order.
// Each composite x * q is written once, with q = its smallest prime factor.
struct linear_sieve {
    vector<int> least_factor, prime_list;
    linear_sieve(int n) : least_factor(n + 1, 0) {
        for (int x = 2; x <= n; ++x) {
            if (least_factor[x] == 0) {  // never marked, so x is prime
                least_factor[x] = x;
                prime_list.push_back(x);
            }
            const int spf = least_factor[x];
            // Mark x * q for each prime q <= spf(x) while x * q stays within n.
            for (size_t k = 0; k < prime_list.size(); ++k) {
                const int q = prime_list[k];
                if (q > spf || static_cast<long long>(x) * q > n) break;
                least_factor[x * q] = q;
            }
        }
    }
};
// Extended Euclid: sets x, y so that a*x + b*y == gcd(|a|, |b|) and returns that gcd.
// Negative inputs are solved on |a|, |b| and the matching coefficient's sign flipped.
// For a == b == 0 it returns 0 with x = 1, y = 0.
ll extgcd(ll a, ll b, ll &x, ll &y) {
    if (a < 0 || b < 0) {
        const ll g = extgcd(abs(a), abs(b), x, y);
        if (a < 0) x = -x;
        if (b < 0) y = -y;
        return g;
    }
    // Iterative form. Invariants: r_prev == a*s_prev + b*t_prev, r_cur == a*s_cur + b*t_cur.
    ll r_prev = a, r_cur = b;
    ll s_prev = 1, s_cur = 0;
    ll t_prev = 0, t_cur = 1;
    while (r_cur != 0) {
        const ll q = r_prev / r_cur;
        // (prev, cur) <- (cur, prev - q * cur), applied to every tracked sequence.
        const auto shift = [q](ll &prev, ll &cur) {
            const ll next = prev - q * cur;
            prev = cur;
            cur = next;
        };
        shift(r_prev, r_cur);
        shift(s_prev, s_cur);
        shift(t_prev, t_cur);
    }
    x = s_prev;
    y = t_prev;
    return r_prev;
}
// Computes a^b mod m by binary exponentiation; the result always lies in [0, m).
//   a : base, any value (reduced into [0, m) first, negatives included)
//   b : exponent, must be >= 0 (asserted; a negative b would never terminate)
//   m : modulus, must be >= 1 (asserted); modpow(x, 0, 1) == 0
// Products are formed in __int128, so every m < 2^63 is handled; a plain 64-bit
// product would overflow once m exceeds roughly 3.04e9.
ll modpow(ll a, ll b, ll m) {
    assert(b >= 0 && m >= 1);
    a %= m;
    if (a < 0) a += m;
    ll res = 1 % m;  // 1 % m keeps the m == 1 case at 0
    while (b > 0) {
        if (b & 1) res = static_cast<ll>(static_cast<__int128>(res) * a % m);
        a = static_cast<ll>(static_cast<__int128>(a) * a % m);
        b >>= 1;
    }
    return res;
}
// Reads a pair as two whitespace-separated values: first, then second.
template <typename T, typename U>
inline istream &operator>>(istream &is, pair<T, U> &rhs) {
    is >> rhs.first;
    is >> rhs.second;
    return is;
}
// Reads v.size() values into the existing elements of v, in index order.
// v is not resized, so size it before reading.
template <typename T>
inline istream &operator>>(istream &is, vector<T> &v) {
    for (auto it = v.begin(); it != v.end(); ++it) is >> *it;
    return is;
}
// Writes a pair as "first second": one space between, no trailing newline.
template <typename T, typename U>
inline ostream &operator<<(ostream &os, const pair<T, U> &rhs) {
    os << rhs.first;
    os << " ";
    os << rhs.second;
    return os;
}
// Writes the elements of v separated by single spaces, with no leading or
// trailing space and no newline. An empty vector writes nothing.
template <typename T>
inline ostream &operator<<(ostream &os, const vector<T> &v) {
    for (size_t i = 0; i < v.size(); ++i) {
        if (i > 0) os << " ";
        os << v[i];
    }
    return os;
}
#line 2 "mod_sqrt.cpp"
// Modular square root by the Tonelli-Shanks algorithm.
//   a : value whose square root is wanted; 0 <= a < p (asserted).
//   p : modulus; must be prime (2 or an odd prime) for the result to be valid.
// Returns some r in [0, p) with r * r % p == a, or -1 when a is a quadratic
// non-residue modulo p. For a >= 2 only one of the two roots (r, p - r) is returned.
// All products are formed in __int128, so any p < 2^63 is supported (a plain
// int64_t product overflows once p exceeds roughly 3.04e9).
int64_t mod_sqrt(const int64_t& a, const int64_t& p) {
    assert(0 <= a && a < p);
    if (a < 2) return a;  // 0 and 1 are their own square roots.
    // (x * y) mod p without 64-bit overflow, for x, y in [0, p).
    const auto mul = [p](int64_t x, int64_t y) -> int64_t {
        return static_cast<int64_t>(static_cast<__int128>(x) * y % p);
    };
    // x^e mod p by binary exponentiation (e >= 0, x in [0, p)).
    const auto power = [&mul](int64_t x, int64_t e) -> int64_t {
        int64_t res = 1;
        while (e > 0) {
            if (e & 1) res = mul(res, x);
            x = mul(x, x);
            e >>= 1;
        }
        return res;
    };
    const int64_t half = (p - 1) / 2;
    // Euler's criterion: a is a square mod p iff a^((p-1)/2) == 1.
    if (power(a, half) != 1) return -1;
    // Factor p - 1 = q * 2^s with q odd.
    int64_t q = p - 1;
    int s = 0;
    while ((q & 1) == 0) {
        q >>= 1;
        ++s;
    }
    // Find a quadratic non-residue z (Euler's criterion does not give 1 for it).
    int64_t z = 2;
    while (power(z, half) == 1) ++z;
    int64_t c = power(z, q);
    int64_t t = power(a, q);
    int64_t r = power(a, (q + 1) / 2);
    // Invariants: r * r == a * t, t^(2^(s-1)) == 1, c has order exactly 2^s.
    while (t != 1) {
        // Least i (0 < i < s) with t^(2^i) == 1.
        int i = 0;
        for (int64_t u = t; u != 1; u = mul(u, u)) ++i;
        // b = c^(2^(s - i - 1))
        int64_t b = c;
        for (int k = 0; k < s - i - 1; ++k) b = mul(b, b);
        r = mul(r, b);
        c = mul(b, b);
        t = mul(t, c);
        s = i;
    }
    return r;
}
Back to top page