洛谷题单指南-树与图上的动态规划-P7077 [CSP-S2020] 函数调用
原题链接:https://www.luogu.com.cn/problem/P7077
题意解读:一个整数序列,有三种函数:1、做单点加法 2、做所有数乘法 3、调用其他函数,给出Q个函数调用按顺序执行,输出整数序列的结果。
解题思路:
1、朴素想法
直接做,对序列的修改借助于线段树,总体复杂度在O(NMlogN),N在10^5,M在10^6,显然无法通过。
2、问题分析
既然不能直接做,换一个角度,可以求得每一个函数对答案的贡献是什么!
假设只有1类函数、2类函数的情况,
对于2类函数,所有乘数的积t,就是对原来数据的影响
对于1类函数,要看后面有多少个2类函数,假设1类函数是对a[x]加2,后面所有2类函数的乘数之积是times,那么这个1类函数的最终效果是a[x] + 2 times,算上2类函数a[x]变成了a[x] t + 2 * times
3、进一步
有了上面分析,问题的关键在于求的每一个1类函数后面有多少个乘数之积,
对于3类函数,同样可以通过逆拓扑序递推出其乘数,因为对于一个节点u,乘数为所有子节点乘数之积。
4、漏掉了什么?
通过上面计算,可以算出q个函数的调用次数,但是有可能存在1类函数在3类函数中,这样就要把3类函数的
调用次数下传到1类函数,并且3类函数下如果有1 2 2这样的函数调用,这里1类函数的次数还要算上后面两个2类函数的乘数之积,并与3类函数调用次数相乘,才得到该1类函数的调用次数。
这里,可以通过对3类函数按拓扑序进行递推,将调用次数下传并累加至所有1类函数。
100分代码:
#include <bits/stdc++.h>
using namespace std;
const int N = 100005, M = 1000005, Q = 100005, MOD = 998244353;
int a[N];
struct Func
{
int t, p, v;
long long mul, times;
} f[N];
int head[N], to[M], nxt[M], idx;
int in[N]; //入度
int g[N];
int n, m, q;
void add(int a, int b)
{
to[++idx] = b;
nxt[idx] = head[a];
head[a] = idx;
in[b]++;
}
int qq[N], hh = 0, tt = -1;
void topsort()
{
for(int i = 1; i <= m; i++)
{
if(!in[i]) qq[++tt] = i;
}
while(hh <= tt)
{
int u = qq[hh++];
for(int i = head[u]; i; i = nxt[i])
{
int v = to[i];
if(--in[v] == 0) qq[++tt] = v;
}
}
}
void calc_mul()
{
for(int i = m - 1; i >= 0; i--)
{
int u = qq[i];
for(int j = head[u]; j; j = nxt[j])
{
int v = to[j];
f[u].mul = f[u].mul * f[v].mul % MOD;
}
}
}
void calc_times()
{
for(int i = 0; i < m; i++)
{
int u = qq[i];
long long t = f[u].times;
for(int j = head[u]; j; j = nxt[j])
{
int v = to[j];
f[v].times = (f[v].times + t) % MOD;
t = t * f[v].mul % MOD;
}
}
}
int main()
{
cin >> n;
for(int i = 1; i <= n; i++) cin >> a[i];
cin >> m;
for(int i = 1; i <= m; i++)
{
cin >> f[i].t;
if(f[i].t == 1)
{
cin >> f[i].p >> f[i].v;
f[i].mul = 1;
}
else if(f[i].t == 2)
{
cin >> f[i].v;
f[i].mul = f[i].v;
}
else
{
f[i].mul = 1;
int cnt;
cin >> cnt;
while(cnt--)
{
int x; cin >> x;
add(i, x);
}
}
}
topsort(); //先拓扑排序
calc_mul(); //计算每个节点的mul值
cin >> q;
for(int i = 1; i <= q; i++) cin >> g[i];
long long t = 1;
for(int i = q; i >= 1; i--) //从后往前计算,每个类型1/3函数执行了多少次
{
int x = g[i];
f[x].times = (f[x].times + t) % MOD; //同一个函数可能多次调用,要累加次数
t = t * f[x].mul % MOD;
}
calc_times(); //将每个函数的times值传递到其子节点
for(int i = 1; i <= n; i++) a[i] = a[i] * t % MOD; //将每个节点的值乘以t,也就是所有的乘法操作
for(int i = 1; i <= m; i++)
{
if(f[i].t == 1) a[f[i].p] = (a[f[i].p] + f[i].v * f[i].times % MOD) % MOD; //类型1函数
}
for(int i = 1; i <= n; i++) cout << a[i] << " "; //输出结果
return 0;
}