Dijkstra的Algo使用STL实现

时间:2016-10-29 04:53:15

标签: c++ algorithm graph stl dijkstra

#include<bits/stdc++.h>
using namespace std;
#define ll long long
vector<pair<ll,ll> >v[100005];
ll dis[100005];
bool visited[100005];
multiset<pair <ll,ll> > s;
int main(){
    ll n,m,from,next,weight,i;
    cin>>n>>m;
    for(i=1;i<=n;i++){
        v[i].clear();
        dis[i]=2e9;
    }
    for(i=1;i<=m;i++){
        cin>>from>>next>>weight;
        v[from].push_back(make_pair(next,weight));
        v[next].push_back(make_pair(from,weight));
    }
    dis[1]=0;
    s.insert({0,1});
    memset(visited,false,sizeof(visited));
    while(!s.empty()){
        pair<ll,ll>p= *s.begin();
        s.erase(s.begin());
        ll x=p.second;
        ll wei=p.first;
        if(visited[x]) continue;
        for(i=0;i<v[x].size();i++){
            ll e=v[x][i].first;
            ll w=v[x][i].second;
            if(dis[x]+w < dis[e]){
                dis[e]=dis[x]+w;
                s.insert({dis[e],e});
            }
        }
    }
    for(i=2;i<=m;i++)
     cout<<dis[i]<<" ";
}

我有Dijkstra Algo的c ++实现,但我想这对所有情况都不适用(更大的测试用例)。谁能帮我解决这个问题。我是否遗漏了某些内容或者没有实施过。 代码输出每个顶点与源顶点的最小距离(即1)。

2 个答案:

答案 0 :(得分:1)

您永远不会写入visited数组。因此可能会多次扫描边缘。简单修复:在if(visited[x]) continue;

之后添加一行
visited[x] = true;

答案 1 :(得分:0)

这是我在O(N)图中求解的解决方案:            #包括         #包括         #include

     typedef long long ll;
    void fs_int(int *x) {
        register int c = getchar_unlocked();
        *x = 0;
        int neg = 0;

        for(; ((c<48 || c>57) && c != '-'); c = getchar_unlocked());

        if(c=='-') {
            neg = 1;
            c = getchar_unlocked();
        }

        for(; c>47 && c<58 ; c = getchar_unlocked()) {
            *x = (*x<<1) + (*x<<3) + c - 48;
        }

        if(neg)
            *x = -(*x);
    } 
    typedef struct {
        int next;
        int val;
        int d;

    }List;
    typedef struct 
    {
        int parent;
        int shrt;
        ll count;
        int on_reg;
        int ch;
    } Node;
    #define MOD 1000000007
    ll get_sum(Node *tr,List *l)
    {
        Node *t, *t2;
        int i,j,n=0,fix;
        ll result;
        static int *reg=NULL,sz=1000;

        if (!reg)
            reg=malloc(sizeof(int)*sz);
        reg[n++]=1;
        int  cur_d;

        while(n)
        {
            ///fix is the limit for the for, it is the shortname of "for ix" :
            // from 0 to fix there are the old values, from fix to n there are the new ones
            fix=n;   
            for (i=0;i<fix;i++)
            {

               //the better way to reduce the complexity is shift the last item to the current one
                t=&tr[reg[i]];
                reg[i--]=reg[--fix];
                reg[fix]=reg[--n];
                t->on_reg=0;

                ///this scores all the edges from departing from this node
                ///the criteria to avoid propagation is the key of the program
                for (j=t->ch;j;j=l[j].next)
                {   
                    if (l[j].val==1) //avoid the root
                        continue;

                    t2=&tr[l[j].val]; //store in some comfortable variable

                    cur_d=t->shrt+l[j].d; 

                    if (t2->shrt!=0 && t2->shrt< cur_d ) ///if my path is heaviest nothing to do
                        continue;
                    else if (t2->shrt ==cur_d) //I found an item with same weight. It was required to count them
                        t2->count++;
                    else if (t2->shrt==0 || t2->shrt>cur_d) //found a unexplored item or my path is lighter
                    {
                        t2->shrt=cur_d;
                        t2->count=1;
                        if (!t2->on_reg) //if not already in the reg, I insert it inside
                        {
                            if (n>=sz)
                            {
                                sz<<=1;
                                reg=realloc(reg, sizeof(int)*sz);
                            }
                            reg[n++]=l[j].val; //at position n
                            t2->on_reg=1;
                        }
                    }

                }
           /* printf ("reg: ");
            for (k=0;k<n;k++)
                printf ("%d ",reg[k]);
                printf ("\n");*/
            }
        }

        //printf ("\n");
        return result;


    }

    typedef long long ll;
    void set_depth(Node *tr, List *l, int rt,int cd,int parent)
    {
        int i;

        tr[rt].parent=parent;
        for (i=tr[rt].ch;i;i=l[i].next)
            if (l[i].val== parent )
                continue;
            else 
                set_depth(tr,l,l[i].val,cd+1,rt);
    }

    int main ()
    {

        int t,n,q,i,u,v,d;
        fs_int(&t);
        int il=1;
        Node tr[100005];
        List l[200005];
        List *tl;
        while (t--)
        {
            fs_int(&n);
            fs_int(&q);
            il=1;

            memset(tr,0,sizeof(tr));
            memset(l,0,sizeof(l));

            for (i=0;i<q;i++)
            {
                fs_int(&u);
                fs_int(&v);
                fs_int(&d);

                tl=&l[il];
                tl->next=tr[u].ch;
                tl->val=v;
                tl->d=d;
                tr[u].ch=il++;


                tl=&l[il];
                tl->next=tr[v].ch;
                tl->val=u;
                tl->d=d;
                tr[v].ch=il++;

            }

           //set_depth(tr,l,1,0,0);
           // print(tr,l,1,0,0);

           get_sum(tr,l);

           ll res=1;
            for (i=2;i<=n;i++)
            {


                res= ( (res%MOD) *(tr[i].count%MOD) )%MOD;
            }   
            printf ("%lld\n",res);
        }

        return 0;
    }

您感兴趣的功能是函数get_sum()。这是一个广泛的第一次搜索,在图表中意味着沿着同心圆检查,这允许您避免无用的传播。它将值存储在一个名为reg的数组中的虚拟圆圈中。在每一步,你都要检查。关于您可以在Ways比赛中自行检查的效率。它有一个最好的时间