1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
| ##include "bits/stdc++.h" using namespace std;
// 64-bit signed integer used for all modular arithmetic in this file.
typedef long long ll;
// Prime modulus applied to every modular computation below.
const ll mod = 998244353;
// Table size used by the sieve arrays declared later in the file.
const int N = 1e6 + 10;

// Helpers for summing a polynomial sequence modulo `mod` by Lagrange
// interpolation on the consecutive nodes 0, 1, 2, ...
//
// Usage: call init(M) once with M at least the largest degree that will be
// interpolated, fill the sample table, then call polysum()/solve().
namespace polysum {
// Loop shorthands. Macros are not scoped by the namespace, so they are kept
// for any code elsewhere in the file that relies on them; the functions in
// this namespace use plain loops instead.
#define rep(i,a,n) for (int i=a;i<n;i++)
#define per(i,a,n) for (int i=n-1;i>=a;i--)

// Capacity of every table below.
const int D=1010000;

// f[i]   : i! mod `mod`            (filled by init)
// g[i]   : (i!)^{-1} mod `mod`     (filled by init)
// p1, p2 : scratch prefix products used by calcn
// b      : scratch copy / prefix sums used by polysum
// num    : sample values read by solve (filled by the caller)
// a, p, h, C are not referenced inside this namespace; they are kept so that
// the set of names the namespace exports is unchanged.
ll a[D],f[D],g[D],p[D],p1[D],p2[D],b[D],h[D][2],C[D], num[D];

// Returns a^b mod `mod` by binary exponentiation. Requires b >= 0.
// The base is reduced with `%` only, so a negative base yields a result with
// the sign C++ `%` gives it.
ll powmod(ll a, ll b) {
    assert(b >= 0);
    ll res = 1;
    a %= mod;
    for (; b; b >>= 1) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
    }
    return res;
}

// Evaluates at point n the polynomial of degree <= d that takes the values
// a[0..d] at the points 0..d. Values in a[] are expected to be reduced
// modulo `mod`. Requires init(M) to have been called with M + 4 >= d.
//
// When n is one of the sample points the stored value is returned directly;
// otherwise the Lagrange formula is evaluated in O(d) using prefix products
// of (n - j) from both ends and the inverse factorials g[].
ll calcn(int d, ll *a, ll n) {
    if (n <= d) {
        assert(n >= 0);  // a negative index would read before the array
        return a[n];
    }
    // p1[i] = prod_{j < i} (n - j),  p2[i] = prod_{j < i} (n - d + j).
    p1[0] = p2[0] = 1;
    for (int i = 0; i <= d; i++) {
        p1[i + 1] = p1[i] * ((n - i + mod) % mod) % mod;
        p2[i + 1] = p2[i] * ((n - d + i + mod) % mod) % mod;
    }
    ll ans = 0;
    for (int i = 0; i <= d; i++) {
        ll t = g[i] * g[d - i] % mod * p1[i] % mod * p2[d - i] % mod * a[i] % mod;
        // The denominator i! (d-i)! carries the sign (-1)^(d-i).
        if ((d - i) & 1) ans = (ans - t + mod) % mod;
        else ans = (ans + t) % mod;
    }
    return ans;
}

// Precomputes factorials f[0..M+4] and inverse factorials g[0..M+4].
// The inverse of the largest factorial is obtained with one modular
// exponentiation (`mod` is prime) and the rest are derived downwards.
void init(int M) {
    assert(M + 4 >= 0 && M + 4 < D);  // stay inside the tables
    f[0] = f[1] = g[0] = g[1] = 1;
    for (int i = 2; i < M + 5; i++) f[i] = f[i - 1] * i % mod;
    g[M + 4] = powmod(f[M + 4], mod - 2);
    for (int i = M + 3; i >= 1; i--) g[i] = g[i + 1] * (i + 1) % mod;
}

// Returns sum_{i=0}^{n-1} a(i) mod `mod`, where a is a polynomial of degree
// <= m given by its values a[0..m] (reduced modulo `mod`).
// An empty range (n <= 0) sums to 0.
//
// The prefix sums of a degree-m polynomial form a polynomial of degree m+1,
// so one extra sample a(m+1) is extrapolated, prefix sums are accumulated in
// b[], and the result is interpolated at n-1.
ll polysum(ll m, ll *a, ll n) {
    if (n <= 0) return 0;  // previously read b[-1] through calcn
    assert(m >= -1 && m + 2 < D);
    for (int i = 0; i <= m; i++) b[i] = a[i];
    b[m + 1] = calcn(m, b, m + 1);
    for (int i = 1; i < m + 2; i++) b[i] = (b[i - 1] + b[i]) % mod;
    return calcn(m + 1, b, n - 1);
}

// Returns the sum of the first n entries of the sequence sampled in num[],
// treating it as a polynomial of degree <= k + 1; num[0..k+1] must be filled.
// With num[i] = (i+1)^k this is 1^k + 2^k + ... + n^k modulo `mod`.
ll solve(ll n, int k) {
    ll ans = polysum(k + 1, num, n) % mod;
    return ans;
}
}
// Sieve flag. Despite the name, an entry is set to true when the index has
// been crossed out as composite; primes keep the default `false`.
bool is_prime[N];
// Primes found by the sieve, stored 1-based in prime[1..cnt].
int prime[N];
// Number of primes stored in prime[].
int cnt;
// Moebius function values; init() later turns them into prefix sums.
int mu[N];
// Input: upper bound of the summation.
ll n;
// Input: exponent of the power sum.
ll k;
void init() { mu[1] = 1; for(int i = 2;i < N; i++) { if(!is_prime[i]) prime[++cnt] = i, mu[i] = -1; for(int j = 1;j <= cnt && i * prime[j] < N; j++) { is_prime[i * prime[j]] = 1; if(i % prime[j] == 0) { mu[i * prime[j]] = 0; break; } else mu[i * prime[j]] = -mu[i]; } } for(int i = 2;i < N; i++) mu[i] += mu[i - 1];
polysum::init(k + 10); for(int i = 0;i <= k + 1; i++) polysum::num[i] = polysum::powmod((ll)i + 1, k); }
// Memo table for S(): argument x >= N -> computed value.
map<ll, ll> mp;

// Returns the prefix sum of the Moebius function, mu(1) + ... + mu(x).
//
// For x < N the value is read from mu[], which init() has already converted
// into prefix sums. Larger arguments use the recurrence
//     S(x) = 1 - sum_{l=2..x} S(floor(x / l)),
// evaluated block by block: every l in [l, r] shares the same quotient x / l,
// so the block contributes (r - l + 1) * S(x / l). Results are memoised in mp.
//
// x must be non-negative.
ll S(ll x) {
    if (x < N) return mu[x];

    // Single lookup; unlike testing `mp[x]`, this also recognises cached
    // values that happen to be 0 and never inserts a placeholder entry.
    map<ll, ll>::iterator hit = mp.find(x);
    if (hit != mp.end()) return hit->second;

    ll ans = 1;
    // The block bounds must be 64-bit: x can exceed the range of int, and an
    // int `r` would be truncated (and `r + 1` could overflow).
    for (ll l = 2, r; l <= x; l = r + 1) {
        r = min(x, x / (x / l));
        ans = ans - (r - l + 1) * S(x / l);
    }
    return mp[x] = ans;
}
void solve() { cin >> n >> k; init(); ll ans = 0; for(ll l = 1, r;l <= n; l = r + 1) { r = min(n, n / (n / l)); ans = (ans + (S(r) - S(l - 1)) * polysum::solve(n / l, k) % mod) % mod; } cout << (ans % mod + mod) % mod << endl; }
// Program entry point: handles the single test case read from standard input.
signed main() {
    solve();
    return 0;
}
|