CSES - Datatähti 2022 alku - Results
Submission details
Task:Tietoverkko
Sender:Totska
Submission time:2021-10-11 17:23:33 +0300
Language:PyPy3
Status:READY
Result:0
Feedback
groupverdictscore
#10
#20
#30
Test results
testverdicttimegroup
#10.18 s1, 2, 3details
#2--2, 3details
#3--3details

Code

from operator import itemgetter
from math import factorial as fc
from copy import deepcopy
n = int(input())

edges = []
z = set()
for i in range(n-1):
    edges.append(tuple(map(int, input().split())))
    z.add(edges[i][2])

edges.sort(key=itemgetter(2))

speeds = {x: [] for x in z}

complist = [-1 for i in range(n+1)]

for e in edges:
    speeds[e[2]].append((e[0], e[1]))

# print(edges)
# print(speeds)


def dfs(node):
    if node in visited:
        return

    comp.add(node)
    visited.add(node)

    for nb in adj[node]:
        dfs(nb)


def gendict(edges):
    nodes = {e[0] for e in edges} | {e[1] for e in edges}

    adj = {e: [] for e in nodes}
    for e in edges:
        adj[e[0]].append(e[1])
        adj[e[1]].append(e[0])

    return adj, nodes




components = []

ans = 0
icomp = 0

for s in speeds:
    newcomponents = []
    edges = speeds[s]

    adj, nodes = gendict(edges)

    # uudet kaaret komponentteihin
    visited = set()

    for nod in nodes:
        comp = set()
        dfs(nod)

        if len(comp) != 0:
            newcomponents.append((comp, s))

            for node in comp:
                complist[node] = icomp

            icomp += 1


    # saman nopeuden komponenttien parit, n choose k
    for c in newcomponents:
        ans += s * (fc(len(c[0])) // (fc(2) * fc(len(c[0]) - 2)))


    def findconnected(start, cur, minspeed, notvisited):
        global ans
        minspeed = min(start[1], cur[1])
        notvisited.remove(cur)

        if len(cur[0] & start[0]) == 1:
            ans += minspeed * (len(start[0]) - 1) * (len(cur[0]) - 1)

        # notvisited = [(comp, speed)]
        for comp in notvisited:
            findconnected(start, comp, minspeed, notvisited)

        return

    #
    if len(components) != 0:
        for nc in newcomponents:
            oldcomp = deepcopy(components)
            findconnected(nc, components[0], nc[1], components)
            components = oldcomp

    components += newcomponents

print(ans)

Test details

Test 1

Group: 1, 2, 3

Verdict:

input
100
1 2 74
1 3 100
2 4 50
3 5 40
...

correct output
88687

user output
10794

Test 2

Group: 2, 3

Verdict:

input
5000
1 2 613084013
1 3 832364259
2 4 411999902
3 5 989696303
...

correct output
1103702320243776

user output
(empty)

Test 3

Group: 3

Verdict:

input
200000
1 2 613084013
1 3 832364259
2 4 411999902
3 5 989696303
...

correct output
1080549209850010931

user output
(empty)