yukicoder No.681 Fractal Gravity Glue

No.681 Fractal Gravity Glue - yukicoder

解法

頑張って解いたので解法を書きます。

まず、以下の関数を定義します。

  • count(i,d)

G_{i,d} の段数を返す

  • sum\_n(d,n)

n 段目までの石の大きさの総和を返す

  • sum(i,d)

G_{i,d} の石の大きさの総和を返す

実装の都合上、sum\_n の引数には count(i,d) \gt n を満たす最も小さな i も含め、

 sum\_n(i,d,n) のように呼び出すことにします。  

i1 から順に増やしていき、count(i,d) \gt n が成り立ったら sum(b,d)-sum\_n(i,d,n) を出力する、という方針で解きます。 以下、各関数の実装です。

count(i,d)

 この関数では、i の値はそれほど大きくならないので、定義通りの再帰関数を書きます。

long long count(long long i,long long d)
{
    if(i<1)return 0;
    else if(i==1)return d;
    else return d+count(i-1,d)*(d+1);
}

sum\_n(i,d,n)

s := count(i-1,d) とすると、石の塔は 1,s,1,s,...,s,1 と重なります。 したがって、t := floor(n/(1+s))1,s の大きさの和を掛けて t \times (i+sum(i-1,d)) 、残りを再帰的に sum\_n(i-1,d,n-t \times (1+s)) として足し合わせれば求まります。
n = 0 を終了条件として、0 を返しておきます。

long long sum_n(long long i,long long d,long long n)
{
    if(n==0)return 0;
    long long s=count(i-1,d);
    long long t=n/(1+s);
    return (t*(i+sum(i-1,d))%mod+sum_n(i-1,d,n-t*(1+s)))%mod;
}

sum(i,d)

count と同じように再帰的に求めようにも、この関数については i にわたる数が大きくなりすぎる可能性がありますから、そこを何とかする必要があります。 とりあえず、再帰を書いてみます。

long long sum(long long i,long long d)
{
    if(i<1)return 0;
    else if(i==1)return i*d%mod;
    else return (i*d%mod+sum(i-1,d)*(d+1))%mod;
}

考えてみると、これはforループで以下のように書けます。

long long sum(long long i,long long d)
{
    long long ans=0;
    for(long long j=1;j<=i;j++)
    {
        ans=(j*d%mod+ans*(d+1)%mod)%mod;
    }
    return ans;
}

さて、分解していくと以下のようになります。

i \times d+( (i-1) \times d +(...+(2 \times d+(1 \times d) \times (d+1)) \times (d+1)...) \times (d+1)) \times (d+1)

1,2,...,i について d が一回、(d+1) が複数回掛けられ、それの総和のようです。
変形して

\sum_{j=1}^{i}j \times d \times (d+1)^{i-j}

=\frac{(d+1) \times ((d+1)^{i}-1)}{d}-i

これなら時間的にも全く問題ありませんね!

long long sum(long long i,long long d)
{
    return((d+1)*(pow(d+1,i)+mod-1)%mod*pow(d,mod-2)%mod-i+mod)%mod;
}

これですべての関数が完成し、ACすることができました。 ソースコードの全体を下に載せておきます。

#include<iostream>
using namespace std;
long long mod=1e9+7;
long long count(long long i,long long d)
{
    if(i<1)return 0;
    else if(i==1)return d;
    else return d+count(i-1,d)*(d+1);
}
long long pow(long long a,long long b)
{
    return b?pow(a*a%mod,b/2)*(b%2?a:1)%mod:1;
}
long long sum(long long i,long long d)
{
    return((d+1)*(pow(d+1,i)+mod-1)%mod*pow(d,mod-2)%mod-i+mod)%mod;
}
long long sum_n(long long i,long long d,long long n)
{
    if(n==0)return 0;
    long long s=count(i-1,d);
    long long t=n/(1+s);
    return (t*(i+sum(i-1,d))%mod+sum_n(i-1,d,n-t*(1+s)))%mod;
}
main()
{
    long long n,b,d;
    cin>>n>>b>>d;
    for(long long i=1;;i++)
    {
        if(count(i,d)>n)
        {
            cout<<(sum(b,d)+mod-sum_n(i,d,n))%mod<<endl;
            return 0;
        }
    }
}