CSES - Datatähti 2025 alku - Results
Submission details
Task:Kortit II
Sender:EmuBird
Submission time:2024-11-06 13:10:22 +0200
Language:Rust (2021)
Status:READY
Result:62
Feedback
groupverdictscore
#1ACCEPTED3
#2ACCEPTED5
#3ACCEPTED26
#4ACCEPTED28
#50
Test results
testverdicttimegroup
#1ACCEPTED0.00 s1, 2, 3, 4, 5details
#2ACCEPTED0.00 s2, 3, 4, 5details
#3ACCEPTED0.00 s3, 4, 5details
#4ACCEPTED0.01 s4, 5details
#5--5details
#6--5details

Code

use std::cell::RefCell;
use std::io;
use std::rc::Rc;

const MOD: u32 = 10u32.pow(9) + 7;
const MOD64: u64 = MOD as u64;

fn main() {
    let stdin = io::stdin();

    let cache: Rc<RefCell<Cache>> = Rc::new(RefCell::new(Cache {
        point_data: vec![],
        ncr_data: vec![],
        fact_data: vec![1],
    }));

    let t: u32 = {
        let mut input: String = String::new();
        stdin.read_line(&mut input).unwrap();
        input.trim().parse().unwrap()
    };

    for _ in 0..t {
        let values: Vec<u32> = {
            let mut input: String = String::new();
            stdin.read_line(&mut input).unwrap();
            input.trim().split_whitespace().map(|x| x.parse::<u32>().unwrap()).collect()
        };

        let answer = generate_answer(values[0], values[1], values[2], &cache);
        println!("{}", answer);
    }
}

fn generate_answer(total_cards: u32, a_points: u32, b_points: u32, cache: &Rc<RefCell<Cache>>) -> u32 {
    let total_points = a_points + b_points;

    if total_points > total_cards || total_points > 0 && (a_points >= total_points || b_points >= total_points) {
        return 0;
    }

    let draws = total_cards - a_points - b_points;

    let ncr_value = ncr(total_cards, draws, &cache);
    if ncr_value == 0 {
        0
    } else {
        let factorial_value = factorial(total_cards, &cache);
        let base_points = get_base_points(a_points, b_points, &cache);
        mod_mult(base_points, mod_mult(ncr_value, factorial_value))
    }
}

fn get_base_points(a: u32, b: u32, cache: &Rc<RefCell<Cache>>) -> u32 {
    if b > a {
        return get_base_points(b, a, cache);
    }

    {
        let c = (*cache).borrow();
        if c.point_data.len() > a as usize && c.point_data[a as usize].len() > b as usize {
            return c.point_data[a as usize][b as usize];
        }
    }
    // If this part is reached, the base points could not be read from cache.

    /// Ensures that after calling this it is safe to get `point_data[a]` and that pushing to that the row has `b - 1` elements.
    /// In other words, `point_data[a].push(something)` will result in `point_data[a][b] = something`.
    /// However, if there is already data at `point_data[a][b]`, nothing will be done.
    fn ensure_prior_values(a: usize, b: usize, cache: &Rc<RefCell<Cache>>) {
        let rows = (*cache).borrow().point_data.len();
        if rows <= a {
            let mut cache = (*cache).borrow_mut();
            for _ in rows..=a {
                cache.point_data.push(vec![]);
            }
        }

        let cols = (*cache).borrow().point_data[a].len();
        if cols < b {
            for i in cols..b {
                get_base_points(a as u32, i as u32, &cache); // this will also write it to cache
            }
        }
    }

    let hardcoded: Option<u32> =
        if b == 0 {
            if a == 0 {
                Some(1)
            } else {
                Some(0)
            }
        } else if b == 1 {
            Some(1)
        } else {
            None
        };
    if hardcoded.is_some() {
        ensure_prior_values(a as usize, b as usize, &cache);
        (*cache).borrow_mut().point_data[a as usize].push(hardcoded.unwrap());
        return hardcoded.unwrap();
    }

    let n = a + b;
    let vertical_sum = {
        let mut sum = 0;
        let b_iter = b - 1;
        for a_iter in (1..a).chain([a + 1]) {
            let points = mod_mult(get_base_points(a_iter, b_iter, &cache), ncr(n, n - a_iter - b_iter, &cache));
            sum = mod_add(sum, points);
        }
        sum
    };

    let horizontal_sum = {
        let mut sum = 0;
        for b_iter in 1..(b - 1) {
            let points = mod_mult(get_base_points(a, b_iter, &cache), ncr(n, n - a - b_iter, &cache));
            sum = mod_add(sum, points);
        }
        sum
    };

    let result = mod_sub(vertical_sum, horizontal_sum);
    ensure_prior_values(a as usize, b as usize, &cache);
    (*cache).borrow_mut().point_data[a as usize].push(result);
    result
}

fn ncr(n: u32, k: u32, cache: &Rc<RefCell<Cache>>) -> u32 {
    if k > n / 2 {
        return ncr(n, n - k, cache);
    } else if k == 0 {
        return 1;
    }

    if (*cache).borrow().ncr_data.len() > n as usize - 2 && (*cache).borrow().ncr_data[n as usize - 2].len() > k as usize - 1 && (*cache).borrow().ncr_data[n as usize - 2][k as usize - 1].is_some() {
        return (*cache).borrow().ncr_data[n as usize - 2][k as usize - 1].unwrap();
    }

    let result = mod_add(ncr(n - 1, k - 1, cache), ncr(n - 1, k, cache));
    let c = &mut (*cache).borrow_mut().ncr_data;
    if c.len() == n as usize - 2 {
        c.push(vec![]);
    }
    for _ in c[n as usize - 2].len()..=(k as usize) {
        c[n as usize - 2].push(None);
    }
    c[n as usize - 2][k as usize - 1] = Some(result);
    result
}

fn factorial(n: u32, cache: &Rc<RefCell<Cache>>) -> u32 {
    if (*cache).borrow().fact_data.len() > n as usize {
        (*cache).borrow().fact_data[n as usize]
    } else {
        let fact = mod_mult(factorial(n - 1, cache), n);
        (*cache).borrow_mut().fact_data.push(fact);
        fact
    }
}

struct Cache {
    point_data: Vec<Vec<u32>>,
    ncr_data: Vec<Vec<Option<u32>>>,
    fact_data: Vec<u32>,
}

fn mod_mult(a: u32, b: u32) -> u32 {
    (((a % MOD) as u64 * (b % MOD) as u64) % MOD64) as u32
}

fn mod_add(a: u32, b: u32) -> u32 {
    (((a % MOD) as u64 + (b % MOD) as u64) % MOD64) as u32
}

fn mod_sub(a: u32, b: u32) -> u32 {
    let a = a as u64 % MOD64;
    let b = b as u64 % MOD64;
    if a >= b {
        ((a - b) % MOD64) as u32
    } else {
        ((a + MOD64 - b) % MOD64) as u32
    }
}

Test details

Test 1

Group: 1, 2, 3, 4, 5

Verdict: ACCEPTED

input
54
4 4 0
3 1 3
3 2 2
4 0 4
...

correct output
0
0
0
0
0
...

user output
0
0
0
0
0
...

Test 2

Group: 2, 3, 4, 5

Verdict: ACCEPTED

input
284
6 1 0
5 0 2
7 1 5
7 7 5
...

correct output
0
0
35280
0
36720
...

user output
0
0
35280
0
36720
...

Test 3

Group: 3, 4, 5

Verdict: ACCEPTED

input
841
19 3 12
19 19 13
19 7 13
20 11 15
...

correct output
40291066
0
0
0
0
...

user output
40291066
0
0
0
0
...

Test 4

Group: 4, 5

Verdict: ACCEPTED

input
1000
15 12 6
7 1 6
44 4 26
6 6 5
...

correct output
0
5040
494558320
0
340694548
...

user output
0
5040
494558320
0
340694548
...

Test 5

Group: 5

Verdict:

input
1000
892 638 599
966 429 655
1353 576 1140
1403 381 910
...

correct output
0
0
0
249098285
0
...

user output
(empty)

Test 6

Group: 5

Verdict:

input
1000
2000 1107 508
2000 1372 249
2000 588 65
2000 1739 78
...

correct output
750840601
678722180
744501884
159164549
868115056
...

user output
(empty)