CSES - Datatähti 2024 alku - Results
Submission details
Task:Uolevin kalansaalis
Sender:Tipu
Submission time:2023-11-04 18:13:05 +0200
Language:C++ (C++11)
Status:COMPILE ERROR

Compiler report

input/code.cpp: In function 'int main()':
input/code.cpp:309:5: error: 'assert' was not declared in this scope
  309 |     assert(p > os);
      |     ^~~~~~
input/code.cpp:7:1: note: 'assert' is defined in header '<cassert>'; did you forget to '#include <cassert>'?
    6 | #include <unordered_map>
  +++ |+#include <cassert>
    7 | using namespace std;

Code

#include <algorithm>
#include <cassert>
#include <cmath>
#include <deque>
#include <iostream>
#include <unordered_map>
#include <utility>
#include <vector>
using namespace std;
// 64-bit integer shorthand (currently not referenced in this file).
using ll = long long;


// Per-cell prefix sums, one per line direction; main() fills them:
//   l: running sum along the "l" diagonal (from the cell up-left),
//   r: running sum along the "r" diagonal (from the cell up-right),
//   x: running sum along the row (from the cell to the left).
struct Hex {
    int l{0};
    int r{0};
    int x{0};
};

// Capacity per grid dimension. main() reads d[i-1][j+1] for j == m on even
// rows, i.e. column m+1, so SIZE must be at least m+2. 502 covers
// n, m <= 500 (assumed task limit — confirm against the statement).
const int SIZE = 502;

int n, m, k;        // grid size (first / second index) and number of marked cells, read in main
int v[SIZE][SIZE];  // cell weight: +1 for 'H', -10 for any other marker, 0 if unmarked
Hex d[SIZE][SIZE];  // direction-wise prefix sums of v, built in main
// Visit memo used by scan(): for a corner cell and a size l, records whether
// that state was already reached while growing "up" (first) or "down" (second).
unordered_map<int, pair<bool, bool>> sm[SIZE][SIZE];

// Sum of v along the "l" diagonal: prefix sum at (i, j) minus the prefix sum
// l rows higher (column shifted left by ceil((l - even_row) / 2)).
// Returns 0 if either endpoint lies outside the [0, n] x [0, m] table.
int count_l_colum(int i, int j, int l){
    const bool even_row = (i % 2 == 0);
    const int col_shift = int(ceil((l - even_row) / 2.0));
    const int top_row = i - l;
    const int top_col = j - col_shift;
    const bool inside = top_row >= 0 && i <= n && top_col >= 0 && j <= m;
    if (!inside) {
        return 0;
    }
    return d[i][j].l - d[top_row][top_col].l;
}

// Sum of v along the "r" diagonal: prefix sum at (i, j) minus the prefix sum
// l rows higher (column shifted right by trunc((l + even_row) / 2)).
// Returns 0 if either endpoint lies outside the [0, n] x [0, m] table.
int count_r_colum(int i, int j, int l){
    const int col_shift = int((l + (i % 2 == 0)) / 2.0);
    const int top_row = i - l;
    const int top_col = j + col_shift;
    if (top_row >= 0 && i <= n && j >= 0 && top_col <= m) {
        return d[i][j].r - d[top_row][top_col].r;
    }
    return 0;
}

// Sum of v over the l cells of row i ending at column j, via row prefix sums.
// Returns 0 if the row or either column endpoint is outside the table.
int count_x_colum(int i, int j, int l){
    const bool row_ok = i >= 0 && i <= n;
    const bool cols_ok = j - l >= 0 && j <= m;
    return (row_ok && cols_ok) ? d[i][j].x - d[i][j - l].x : 0;
}
// Greedy region search. From every start cell (i, j) it walks a region
// described by corner cells a, b and size l: at each step it adds the
// neighbouring column (count_*_colum) with the most negative sum, or removes
// an edge column whose removal lowers the running sum s, and stops when no
// move lowers s, when an (a, l, direction) state was already seen, or after
// 1e6 steps.
// NOTE(review): the memo sm is global and never cleared here, so states seen
// from an earlier start cell also stop later walks — confirm this is intended.
// p: initial upper bound. Returns min(p, smallest s reached over all starts).
int scan(int p){

    pair<short, short> a, b;

    for (int i = n; i >= 1; i--) {
        for (int j = 1; j <= m; j++)
        {
            a = {i, j};
            b = {i, j};
            int s = count_x_colum(i, j, 1);
            int l = 1;
            bool up = true;
            int safety = 0;
            // Hard step cap so a cycling walk cannot run forever.
            while(safety < 1e6){
                safety++;
                auto it = sm[a.first][a.second].find(l);

                if(l == 1) up = true;

                // Check the memo: stop if this state was already reached in
                // the current direction, otherwise mark it.
                if(it != sm[a.first][a.second].end()){
                    if((up && it->second.first) || (!up && it->second.second)){
                        break;
                    }
                    else{
                        if(up) it->second.first = true;
                        else it->second.second = true;
                    }
                // First time at (a, l): add a memo entry.
                } else{
                    if (up) sm[a.first][a.second].insert({l, {true, false}});
                    else sm[a.first][a.second].insert({l, {false, true}});
                }


                bool add = (b.first % 2 == 0);
                // Single cell: may only grow, either upward or downward.
                if(l == 1){

                    int su1 = count_l_colum(b.first, b.second + 1, l + 1);
                    int su2 = count_x_colum(b.first + 1, b.second + add, l + 1);
                    int su3 = count_r_colum(a.first, a.second - 1, l + 1);

                    int su = min({su1, su2, su3});

                    int sd1 = count_l_colum(b.first + 1, b.second + (add - 1), l + 1);
                    int sd2 = count_x_colum(b.first - l, b.second + int((l+add)/2.0), l + 1);
                    int sd3 = count_r_colum(b.first + 1, b.second + add, l + 1);

                    int sd = min({sd1, sd2, sd3});

                    if(su <= sd && su < 0){
                        up = true;
                        l += 1;
                        if(su1 <= su2 && su1 <= su3){
                            s+=su1;
                            b = {b.first, b.second + 1};
                            continue;
                        }
                        if(su2 <= su3 && su2 <= su1){
                            s += su2;
                            b = {b.first + 1, b.second + add};
                            a = {a.first + 1, a.second + (add - 1)};
                            continue;
                        }
                        if(su3 <= su1 && su3 <= su2){
                            s += su3;
                            a = {a.first, a.second - 1};
                            continue;
                        }
                    }
                    else if (sd < 0){
                        up = false;
                        l += 1;
                        if(sd1 <= sd2 && sd1 <= sd3){
                            s+=sd1;
                            a = {a.first, a.second - 1};
                            b = {b.first + 1, b.second + (add - 1)};
                            continue;
                        }
                        if(sd2 <= sd3 && sd2 <= sd1){
                            s += sd2;
                            a = {a.first - 1, a.second + (add - 1)};
                            continue;
                        }
                        if(sd3 <= sd1 && sd3 <= sd2){
                            s += sd3;
                            b = {b.first + 1, b.second + add};
                            continue;
                        }
                    }
                    else{
                        break;
                    }
                }
                // Larger region: grow (s*) or shrink (r*) in the current direction.
                if(l != 0){
                    if(up){
                        int su1 = count_l_colum(b.first, b.second + 1, l + 1);
                        int su2 = count_x_colum(b.first + 1, b.second + add, l + 1);
                        int su3 = count_r_colum(a.first, a.second - 1, l + 1);

                        int su = min({su1, su2, su3});

                        // Removing an edge column changes s by minus its sum.
                        int ru1 = -1 * count_l_colum(b.first, b.second, l);
                        int ru2 = -1 * count_x_colum(b.first, b.second, l);
                        int ru3 = -1 * count_r_colum(a.first, a.second , l);

                        int ru = min({ru1, ru2, ru3});

                        if(su <= ru && su < 0){
                            l += 1;
                            if(su1 <= su2 && su1 <= su3){
                                s+=su1;
                                b = {b.first, b.second + 1};
                                continue;
                            }
                            if(su2 <= su3 && su2 <= su1){
                                s += su2;
                                b = {b.first + 1, b.second + add};
                                a = {a.first + 1, a.second + (add - 1)};
                                continue;
                            }
                            if(su3 <= su1 && su3 <= su2){
                                s += su3;
                                a = {a.first, a.second - 1};
                                continue;
                            }
                        }
                        else if(ru < 0){
                            l -= 1;
                            if(ru1 <= ru2 && ru1 <= ru3){
                                s += ru1;
                                b = {b.first, b.second - 1};
                                continue;
                            }
                            if(ru2 <= ru3 && ru2 <= ru1){
                                s += ru2;
                                b = {b.first - 1, b.second + (add - 1)};
                                a = {a.first - 1, a.second + add};
                                continue;
                            }
                            if(ru3 <= ru1 && ru3 <= ru2){
                                s += ru3;
                                a = {a.first, a.second + 1};
                                continue;
                            }
                        }
                        else{
                            break;
                        }
                    }
                    else{
                        int sd1 = count_l_colum(b.first + 1, b.second + (add - 1), l + 1);
                        int sd2 = count_x_colum(b.first - l, b.second + int((l+add)/2.0), l + 1);
                        int sd3 = count_r_colum(b.first + 1, b.second + add, l + 1);

                        int sd = min({sd1, sd2, sd3});

                        // Removing an edge column changes s by minus its sum.
                        int rd1 = -1 * count_l_colum(b.first, b.second, l);
                        int rd2 = -1 * count_x_colum(a.first, a.second + l - 1, l);
                        int rd3 = -1 * count_r_colum(b.first, b.second, l);

                        int rd = min({rd1, rd2, rd3});

                        if(sd <= rd && sd < 0){
                            l += 1;
                            if(sd1 <= sd2 && sd1 <= sd3){
                                s+=sd1;
                                a = {a.first, a.second - 1};
                                b = {b.first + 1, b.second + (add - 1)};
                                continue;
                            }
                            if(sd2 <= sd3 && sd2 <= sd1){
                                s += sd2;
                                a = {a.first - 1, b.second - int(ceil((l-add)/2.0))};
                                continue;
                            }
                            if(sd3 <= sd1 && sd3 <= sd2){
                                s += sd3;
                                b = {b.first + 1, b.second + add};
                                continue;
                            }
                        }
                        else if(rd < 0){
                            l -= 1;
                            if(rd1 <= rd2 && rd1 <= rd3){
                                s+=rd1;
                                a = {a.first, a.second + 1};
                                b = {b.first - 1, b.second + add};
                                continue;
                            }
                            if(rd2 <= rd3 && rd2 <= rd1){
                                s += rd2;
                                a = {a.first + 1, b.second - int((l-add)/2.0)};
                                continue;
                            }
                            if(rd3 <= rd1 && rd3 <= rd2){
                                // Apply the column actually chosen (rd3), not sd3.
                                s += rd3;
                                b = {b.first - 1, b.second + (add - 1)};
                                continue;
                            }
                        }
                        else{
                            break;
                        }
                    }
                }
            }
            p = min(s, p);
        }
    }
    return p;
}


int main() {
    // freopen("input.txt", "r", stdin);



    cin >> n >> m >> k;

    int a, b;
    char c;
    for (int i = 0; i < k; i++) {
        cin >> a >> b >> c;
        if(c == 'H'){
            v[a][b] = 1;
        }else{
            v[a][b] = -10;
        }
    }
    int os = 0;

    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            bool add = 0;
            if(i % 2 == 0){
                add = 1;
            }
            os += v[i][j];
            d[i][j].l = v[i][j] + d[i-1][j-1+add].l;
            d[i][j].r = v[i][j] + d[i-1][j+add].r;
            d[i][j].x = v[i][j] + d[i][j-1].x;
        }
    }

    // for (int i = 1; i <= n; i++) {
    //     for (int j = 1; j <= m; j++) {
    //         cout << "(" << d[i][j].l << " ; " << d[i][j].r << " ; " << d[i][j].x << ")  " ;
    //     }
    //     cout << "\n";
    // }

    int p = 1e6;
    p = scan(p);

    // cout << p << endl;
    // cout << os << endl;
    assert(p > os);
    cout << os - p;
}