CSES - Aalto Competitive Programming 2024 - wk10 - Homework - Results
Submission details
Task:Line Intersections
Sender:Farah
Submission time:2024-11-11 17:54:57 +0200
Language:C++ (C++20)
Status:COMPILE ERROR

Compiler report

input/code.cpp:45:15: error: invalid preprocessing directive #Horizontal
   45 |             # Horizontal line
      |               ^~~~~~~~~~
input/code.cpp:52:15: error: invalid preprocessing directive #Vertical
   52 |             # Vertical line
      |               ^~~~~~~~
input/code.cpp:58:15: error: invalid preprocessing directive #Should
   58 |             # Should not happen according to problem statement
      |               ^~~~~~
input/code.cpp:82:15: error: invalid preprocessing directive #add
   82 |             # add
      |               ^~~
input/code.cpp:85:15: error: invalid preprocessing directive #remove
   85 |             # remove
      |               ^~~~~~
input/code.cpp:88:15: error: invalid preprocessing directive #query
   88 |             # query
      |               ^~~~~
input/code.cpp:1:1: error: 'import' does not name a type
    1 | import sys
      | ^~~~~~
input/code.cpp:1:1: note: C++20 'import' only available with '-fmodules-ts', which is not...

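Code (this is a Python 3 program; it was submitted with the C++ (C++20) language setting, so the compiler parsed each `#` comment as a preprocessor directive and rejected `import sys`, producing the errors above)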

import sys
import sys
import sys

def main():
    """Read axis-aligned segments from stdin and print how many intersect.

    Input (whitespace-separated integers): ``n``, then ``n`` groups of
    ``x1 y1 x2 y2``. A segment with ``y1 == y2`` is treated as horizontal;
    one with ``x1 == x2`` (and ``y1 != y2``) as vertical. Any other segment
    is ignored (the problem statement guarantees none occur). Empty input
    prints 0.

    Algorithm: sweep over ``y``. Each vertical segment is "active" in a
    Fenwick tree (indexed by compressed x-coordinate) for every row
    ``y1 <= y <= y2``. Each horizontal segment at row ``y`` adds the number
    of active verticals whose x lies in ``[x1, x2]``. O(n log n) overall.
    """
    import sys

    class BIT:
        """Fenwick tree over 1-based positions ``1..size`` storing counts."""

        def __init__(self, size):
            self.N = size + 2
            self.tree = [0] * self.N

        def add(self, idx, delta):
            """Add ``delta`` at 1-based position ``idx``."""
            while idx < self.N:
                self.tree[idx] += delta
                idx += idx & -idx

        def sum(self, idx):
            """Return the total over positions ``1..idx`` (0 if ``idx <= 0``)."""
            res = 0
            while idx > 0:
                res += self.tree[idx]
                idx -= idx & -idx
            return res

    values = list(map(int, sys.stdin.read().split()))
    n = values[0] if values else 0

    h_lines = []  # (y, x_left, x_right)
    v_lines = []  # (x, y_bottom, y_top)
    x_set = set()
    # Group the flat coordinate list into (x1, y1, x2, y2) quadruples.
    it = iter(values[1:1 + 4 * n])
    for x1, y1, x2, y2 in zip(it, it, it, it):
        if y1 == y2:
            # Horizontal line
            if x1 > x2:
                x1, x2 = x2, x1
            h_lines.append((y1, x1, x2))
            x_set.add(x1)
            x_set.add(x2)
        elif x1 == x2:
            # Vertical line
            if y1 > y2:
                y1, y2 = y2, y1
            v_lines.append((x1, y1, y2))
            x_set.add(x1)
        # Otherwise the segment is diagonal; the problem statement rules this
        # out, so it is deliberately ignored.

    # Coordinate-compress x so the Fenwick tree size is O(n), 1-based.
    sorted_x = sorted(x_set)
    x_to_idx = {x: i for i, x in enumerate(sorted_x, start=1)}

    # Event kinds. At equal y, tuples sort as ADD < QUERY < REMOVE, so a
    # vertical segment is visible to horizontals on both of its end rows
    # (inclusive) and to none past its top row. (Removing at y2 + 1 with a
    # kind that sorts after QUERY would wrongly count horizontals at y2 + 1.)
    ADD, QUERY, REMOVE = 0, 1, 2
    events = []
    for x, y_lo, y_hi in v_lines:
        x_c = x_to_idx[x]
        events.append((y_lo, ADD, x_c, 0))
        events.append((y_hi, REMOVE, x_c, 0))
    for y, x_lo, x_hi in h_lines:
        events.append((y, QUERY, x_to_idx[x_lo], x_to_idx[x_hi]))
    events.sort()

    bit = BIT(len(sorted_x))
    total = 0
    for _, kind, a, b in events:
        if kind == ADD:
            bit.add(a, 1)
        elif kind == REMOVE:
            bit.add(a, -1)
        else:
            # Count active verticals with compressed x in [a, b].
            total += bit.sum(b) - bit.sum(a - 1)
    print(total)

if __name__ == "__main__":
    # Run the solver only when executed as a script, not when imported.
    main()