Cod sursa(job #2941138)

Utilizator TeodoraMaria123 (Serban Teodora Maria) Data 17 noiembrie 2022 10:37:16
Problema Ubuntzei Scor 100
Compilator cpp-64 Status done
Runda Arhiva de probleme Marime 4.2 kb
#include <fstream>
#include <vector>
#include <queue>
#include <bitset>
#include <cstring>
#include <cassert>

using namespace std;

ifstream in("ubuntzei.in");
ofstream out("ubuntzei.out");

// Problem limits: at most MAX_N towns and MAX_K friends to visit.
// Typed constants instead of macros, so they are scoped and type-checked.
constexpr int MAX_N = 2000;
constexpr int MAX_K = 15;
// One dp column per subset of friends; derived from MAX_K so the two cannot drift apart.
constexpr int MAX_STATE = 1 << MAX_K;
// "Unreachable" sentinel; small enough that INF + INF still fits in an int.
constexpr int INF = 1000000000;

int n, m, k, minimumDist = INF;
// dp[last][mask]: cheapest walk from node 0 that has visited exactly the
// friends in `mask` and currently stands at friends[last].
int dp[MAX_K][MAX_STATE];
vector < vector <pair <int, int> > > graph;   // adjacency list: (neighbour, cost), 0-based
vector <vector <int> > dist;                  // dist[f][v]: shortest distance from friends[f] to node v
priority_queue < pair <int, int> > pq;        // max-heap of (-distance, node) pairs
vector <int> friends;                         // 0-based nodes that must be visited
bitset <MAX_N> visited;                       // per-node "distance finalised" marks

/// Reads one test instance from `in`:
///   n m k            -- node count, edge count, number of friends
///   k friend labels  -- 1-based node numbers
///   m lines "a b c"  -- undirected edge a-b with cost c (1-based labels)
/// Labels are stored 0-based. Sizes the globals friends (k), graph (n) and dist (k);
/// every edge is inserted in both directions.
void readGraph()
{
    in >> n >> m >> k;
    friends.resize(k);
    graph.resize(n);
    dist.resize(k);

    for(int &friendNode : friends)
    {
        int label;
        in >> label;
        friendNode = label - 1;
    }

    for(int edge = 0; edge < m; edge ++)
    {
        int from, to, cost;
        in >> from >> to >> cost;
        --from;
        --to;
        assert(from < n  &&  to < n);
        graph[from].emplace_back(to, cost);
        graph[to].emplace_back(from, cost);
    }
}

/// Single-source shortest paths (Dijkstra with a lazily-pruned binary heap)
/// over the global adjacency list `graph` with `n` vertices.
///
/// @param distanceForANode  output; on return it has exactly n entries, entry v
///                          being the shortest distance from `source` to v, or
///                          INF if v is unreachable. Previous contents are discarded.
/// @param source            0-based start vertex, must satisfy 0 <= source < n.
///
/// Uses the globals `pq` (drained completely before returning) and `visited`
/// (reset here), so repeated calls do not depend on the caller clearing state.
/// NOTE(review): assumes non-negative edge costs, as Dijkstra requires.
void dijkstra(vector <int> &distanceForANode, int source)
{
    assert(source >= 0  &&  source < n);

    // assign (not resize): resize keeps existing entries, so a vector reused
    // from an earlier run would silently keep its stale distances.
    distanceForANode.assign(n, INF);
    // Clear the finalised marks here instead of relying on every caller to do it.
    for(int v = 0; v < n; v ++)
        visited[v] = 0;

    assert(pq.empty());   // every run drains the shared heap completely
    distanceForANode[source] = 0;
    pq.push({0, source});
    while(!pq.empty())
    {
        // Entries are (-distance, node), so the max-heap yields the closest node first.
        int node = pq.top().second;
        pq.pop();
        if(visited[node])
            continue;   // stale entry: node was already finalised via a shorter path
        visited[node] = 1;
        for(const auto &edge : graph[node])
        {
            int next = edge.first;
            int candidate = distanceForANode[node] + edge.second;
            if(!visited[next]  &&  candidate < distanceForANode[next])
            {
                distanceForANode[next] = candidate;
                pq.push({-candidate, next});
            }
        }
    }
}

/// Debug helper: dumps `graph` to `out` with 1-based node numbers.
/// Each vertex gets a "v: " header, then one "neighbour cost" line per
/// incident edge, then an empty line.
void printGraph()
{
    for(int node = 0; node < n; node ++)
    {
        out << node + 1 << ": ";
        for(const auto &edge : graph[node])
        {
            const int neighbour = edge.first + 1;
            out << neighbour << " " << edge.second << "\n";
        }
        out << "\n";
    }
}

/// Computes and writes the length of the shortest walk from node 0 to node
/// n-1 that passes through every friend node.
///
///  1. One Dijkstra per friend: dist[f][v] = distance from friends[f] to v.
///  2. With no friends the answer is just the 0 -> n-1 distance.
///  3. Otherwise, run a bitmask DP: dp[last][chosen] = cheapest walk starting at
///     node 0 that has visited exactly the friends in `chosen` and stands at
///     friends[last].
///  4. Close the walk with the leg from the last friend to node n-1.
///
/// Side effects: fills the globals dist, dp and minimumDist, clears `visited`
/// after each Dijkstra run, and writes one line to `out`.
void solve()
{
    dist.resize(k);
    for(int f = 0; f < k; f ++)
    {
        dijkstra(dist[f], friends[f]);
        for(int v = 0; v < n; v ++)
            visited[v] = 0;   // dijkstra leaves its marks behind; reset for the next run
    }

    if(k == 0)
    {
        dist.resize(1);
        dijkstra(dist[0], 0);
        out << dist[0][n - 1] << "\n";
        return;
    }

    const int fullSet = (1 << k) - 1;

    for(int last = 0; last < k; last ++)
    {
        for(int chosen = 0; chosen <= fullSet; chosen ++)
            dp[last][chosen] = INF;
        // Go straight from node 0 to friends[last]; edges are stored in both
        // directions, so dist[last][0] is also the 0 -> friends[last] distance.
        dp[last][1 << last] = dist[last][0];
    }

    for(int chosen = 1; chosen <= fullSet; chosen ++)
    {
        for(int from = 0; from < k; from ++)
        {
            if(!(chosen & (1 << from)))
                continue;
            for(int to = 0; to < k; to ++)
            {
                if(to == from  ||  !(chosen & (1 << to)))
                    continue;
                // `to` is in `chosen`, so clearing its bit equals subtracting it.
                const int before = chosen - (1 << to);
                if(dp[from][before] == INF)
                    continue;
                assert(before >= 1);
                assert(friends[to] < n);
                dp[to][chosen] = min(dp[to][chosen], dp[from][before] + dist[from][friends[to]]);
            }
        }
    }

    for(int last = 0; last < k; last ++)
    {
        assert(fullSet < MAX_STATE);
        minimumDist = min(minimumDist, dp[last][fullSet] + dist[last][n - 1]);
    }

    out << minimumDist << "\n";
}

/// Debug helper: prints the dp table, one row per ending friend. Each row lists
/// the values for every non-empty friend subset (mask 1 .. 2^k - 1); a blank
/// line follows the table.
void printDp()
{
    const int fullSet = (1 << k) - 1;
    for(int last = 0; last < k; last ++)
    {
        int mask = 1;
        while(mask <= fullSet)
        {
            out << dp[last][mask] << " ";
            mask ++;
        }
        out << "\n";
    }
    out << "\n";
}

void printDist()
{
    for(int i = 0; i < k ; i++)
    {
        for(int j = 0; j < k; j ++)
            out <<dist[i][j] << " ";
        out << "\n";
    }
    out << "\n";
}


/// Entry point: read the instance, solve it, and exit with status 0.
/// For debugging, printGraph() can be called after readGraph(), and
/// printDist() / printDp() after solve().
int main()
{
    readGraph();
    solve();
    return 0;
}