mod submission;
use std::{
sync::{Arc, RwLock},
thread,
time::Duration,
};
use submission::*;
// Number of worker threads spawned by `main` to train the shared network concurrently.
const THREADS: usize = 8;
/// Entry point: spawns `THREADS` workers that all train one shared network.
///
/// Each worker repeatedly (a) accumulates gradients for the whole training set
/// under a read lock, (b) applies them under a write lock, and (c) prints its
/// per-sample loss. Workers run concurrently over the same data, so updates
/// from different threads interleave (hogwild-style). Thread 0 additionally
/// persists the network to disk after every 40 epochs.
pub fn main() {
    // `data` = inputs, `output_data` = expected outputs; assumed to be equal
    // length and index-aligned (indexed in lockstep below) — TODO confirm.
    let (data, output_data) = read_training_data();
    // Shared model: many readers during gradient accumulation, one writer per apply.
    let neural_net: Arc<RwLock<NeuralNetwork<INPUT_COUNT, OUTPUT_COUNT, 18, 0>>> =
        Arc::new(RwLock::new(get_neural_network()));
    for thread in 0..THREADS {
        let output = output_data.clone();
        let input = data.clone();
        let nn = neural_net.clone();
        // Handles are intentionally dropped: workers loop forever and are
        // never joined; the main thread just keeps the process alive below.
        thread::spawn(move || {
            // Loss at the time of the last checkpoint; start at MAX so the
            // first completed batch of epochs always triggers a checkpoint.
            let mut last_save_loss = f64::MAX;
            loop {
                let mut last_loss = 0.;
                let mut loss = 0.;
                // 40 full passes over the training set between checkpoints.
                for _ in 0..40 {
                    // Fresh gradient accumulator for this epoch.
                    let mut changes = NeuralNetwork::new();
                    loss = 0.;
                    //let thread_batch_size = input.len() / THREADS;
                    for i in 0..input.len() {
                        //let i = (pseudo_rand() * input.len() as f64).floor() as usize;
                        // Read lock per sample: other threads may apply changes
                        // between samples, so gradients can be slightly stale.
                        loss += nn.read().unwrap().train(
                            &mut changes,
                            &input[i],
                            &output[i],
                            input.len(),
                        );
                    }
                    nn.write().unwrap().apply_changes(&changes, input.len());
                    println!(
                        "Thread: {}, Loss: {}, Change {}",
                        thread,
                        loss / input.len() as f64,
                        // Ratio of previous to current total loss (>1 means the
                        // loss decreased); prints 0 on the first epoch since
                        // last_loss starts at 0.
                        (last_loss / loss)
                    );
                    last_loss = loss;
                }
                // Only thread 0 touches the filesystem, avoiding concurrent writes.
                if thread == 0 {
                    // Checkpoint only on a >0.1% improvement over the last checkpoint.
                    // NOTE(review): this writes Rust `Debug` output to a .ron file —
                    // confirm whatever loads these actually parses this format.
                    if last_save_loss * 0.999 > loss {
                        std::fs::write(
                            format!("./checkpoints/checkpoint-{:.10}.ron", loss),
                            format!("{:#?}", nn.read().unwrap()),
                        )
                        .expect("Can't save network");
                        last_save_loss = loss;
                        println!("Checkpoint created!");
                    }
                    // Always refresh the "latest" snapshot regardless of improvement.
                    std::fs::write("./neuralnet.ron", format!("{:#?}", nn.read().unwrap()))
                        .expect("Can't save network");
                    println!("saved!");
                }
            }
        });
    }
    // Workers run forever; park the main thread so the process never exits.
    loop {
        thread::sleep(Duration::from_secs(100));
    }
}
// TRAIN ERROR DETECTION MODEL
/// Stub entry point for training an error-detection model — not yet implemented.
pub fn main2() {}
// MODEL EXTENSION STUFF
// Divisor applied to the freshly randomised weights/biases of newly added
// neurons/layers when extending a network (larger => new parameters start
// closer to zero). At 1.0 the random initialisation is left unchanged.
// (Identifier spelling "EXTENTION" kept: it is referenced by the extend fns.)
const EXTENTION_RANDOMNESS: f64 = 1.;
/// Entry point for widening a trained network: takes a hard-coded snapshot of
/// a network with 18 hidden neurons (and 0 extra hidden layers), extends it to
/// 24 hidden neurons via `extend_network`, and prints the result so it can be
/// pasted/saved as the new starting point.
///
/// The weight/bias literals below are a dumped checkpoint; rows of
/// `input_layer.weights` are per-hidden-neuron (18 rows x 24 inputs — assumes
/// INPUT_COUNT == 24, TODO confirm), and `output_layer.weights` are 12 rows x
/// 18 hidden (assumes OUTPUT_COUNT == 12, TODO confirm).
pub fn main_ext() {
    let extended: NeuralNetwork<INPUT_COUNT, OUTPUT_COUNT, 24, 0> = extend_network(NeuralNetwork {
        input_layer: Layer {
            // 18 rows (one per hidden neuron) x 24 input weights each.
            weights: [
                [
                    0.016365807084969798,
                    0.029897443645720924,
                    -0.09905563364983566,
                    0.06512664253557854,
                    -0.063178765477455,
                    0.020710813433538013,
                    0.09303970250163338,
                    -0.02275117349953524,
                    0.04953180021655051,
                    -0.06677368151043744,
                    0.004894784325966019,
                    -0.03065278408558707,
                    0.04068372499251621,
                    -0.07632304198682839,
                    -0.005478221953315343,
                    0.07659918291633323,
                    -0.053830964271285746,
                    0.02709068626942456,
                    -0.10410122662023238,
                    0.057646589214514564,
                    -0.07330684859940739,
                    0.007844677810847711,
                    0.07736899740747706,
                    -0.041076763144369574,
                ],
                [
                    -0.0728924,
                    0.00979859999999999,
                    -0.026329,
                    0.04454899999999999,
                    -0.07276,
                    -0.0018819999999999948,
                    0.08080899999999999,
                    -0.048312999999999995,
                    0.02256499999999999,
                    -0.094744,
                    0.05731520000000001,
                    -0.05999380000000001,
                    0.01088420000000001,
                    0.0935752,
                    -0.0355468,
                    0.04714419999999999,
                    -0.0819778,
                    0.0007131999999999916,
                    -0.047227399999999996,
                    0.0354636,
                    -0.0936584,
                    -0.010967400000000006,
                    0.059910599999999994,
                    -0.0573984,
                ],
                [
                    -0.0920164,
                    0.06004279999999999,
                    -0.06907920000000001,
                    0.013611799999999997,
                    0.0963028,
                    -0.03281920000000001,
                    0.0498718,
                    -0.0792502,
                    0.0034407999999999995,
                    -0.0444998,
                    0.026378199999999997,
                    -0.0909308,
                    -0.020052800000000003,
                    0.0626382,
                    -0.06648380000000001,
                    0.016207199999999998,
                    0.0870852,
                    -0.030223800000000002,
                    -0.0781646,
                    0.004526399999999997,
                    0.0754044,
                    -0.0419046,
                    0.028973399999999993,
                    -0.0883356,
                ],
                [
                    0.05814753284170782,
                    0.022019625628913065,
                    0.09291114524596923,
                    -0.024360965009895385,
                    0.046576298820485904,
                    -0.07064855977605915,
                    0.0003442923707864423,
                    0.08315647477343562,
                    -0.04586000475991252,
                    -0.08191198188308406,
                    -0.01097617988375457,
                    0.07175877053964407,
                    -0.05732611781032156,
                    0.02538895658859303,
                    -0.1037287887218909,
                    -0.021076726077198935,
                    0.049704075375251885,
                    0.0016327966849441656,
                    0.08419311619793376,
                    -0.045040977347176256,
                    0.03755813645984853,
                    -0.0916341992547326,
                    -0.009000727202543574,
                    0.06182809521468674,
                ],
                [
                    -0.0845224,
                    -0.013644400000000001,
                    0.0690466,
                    -0.0600754,
                    0.022615600000000003,
                    0.0934936,
                    -0.023815399999999997,
                    0.047062599999999996,
                    0.010935000000000005,
                    0.081813,
                    -0.03549600000000001,
                    0.035382000000000004,
                    -0.081927,
                    -0.011048999999999998,
                    0.07164200000000001,
                    -0.057479999999999996,
                    -0.09360779999999999,
                    -0.0227298,
                    0.059961200000000006,
                    -0.0691608,
                    0.013530200000000003,
                    0.0844082,
                    -0.032900799999999994,
                    0.0379772,
                ],
                [
                    0.08454060000000001,
                    -0.04458140000000001,
                    0.03810960000000001,
                    -0.09101240000000001,
                    -0.008321400000000001,
                    0.0625566,
                    -0.0665654,
                    0.016125600000000007,
                    -0.0318152,
                    0.050875800000000006,
                    -0.0782462,
                    0.004444800000000004,
                    0.0753228,
                    -0.041986199999999994,
                    0.028891800000000002,
                    -0.0884172,
                    0.06364220000000001,
                    -0.053666799999999994,
                    0.017211200000000003,
                    0.09990220000000001,
                    -0.029219799999999997,
                    0.05347120000000001,
                    -0.07565079999999999,
                    0.007040200000000008,
                ],
                [
                    0.041790400000000005,
                    -0.0755186,
                    -0.004640599999999995,
                    0.0780504,
                    -0.051071599999999995,
                    0.031619400000000006,
                    -0.0975026,
                    0.06636979999999999,
                    -0.06275220000000001,
                    0.01993879999999999,
                    0.0908168,
                    -0.0264922,
                    0.04438580000000001,
                    -0.0729232,
                    -0.0020452000000000027,
                    -0.038173,
                    0.032705000000000005,
                    -0.084604,
                    -0.013726000000000006,
                    0.057152,
                    -0.060156999999999995,
                    0.010721000000000003,
                    0.09341200000000001,
                    -0.03571,
                ],
                [
                    0.021881923817085223,
                    0.0918126470571436,
                    -0.026232508640812918,
                    0.044781804796862355,
                    -0.07233746336412805,
                    -0.0024859206428952665,
                    0.07983917386363963,
                    0.03287112133509152,
                    -0.0823357424927479,
                    -0.010949659591082109,
                    0.07261769878154778,
                    -0.056106219780217584,
                    0.02705228545699406,
                    0.09747492779352068,
                    -0.021731386521725345,
                    -0.0725493652479984,
                    0.008829704163125668,
                    0.07952768013057757,
                    -0.03914040352252662,
                    0.029629967522289514,
                    -0.09208511158402441,
                    -0.028596120749887797,
                    0.04334748364583672,
                    -0.018722391194185176,
                ],
                [
                    -0.1141335732802929,
                    -0.02217967008433491,
                    -0.13957899387935324,
                    -0.06369607171990145,
                    0.03776022011790723,
                    -0.0158651395288838,
                    0.002021670337776273,
                    0.13163616974121367,
                    0.24004052022292313,
                    0.14378283775737707,
                    0.22206453622295347,
                    0.32138829010222797,
                    0.21182148622856187,
                    0.3180598973011594,
                    0.3088127346202453,
                    0.42083778788654014,
                    0.2843404032218277,
                    0.3159252145324529,
                    0.15139481580682,
                    0.20737345575328897,
                    0.2567586555405705,
                    0.13964278359327084,
                    0.21377541348429444,
                    0.18236318389323095,
                ],
                [
                    -0.06256561132330053,
                    0.020110386937770633,
                    0.09096908320229875,
                    -0.026382161730681742,
                    0.04443830075750825,
                    -0.07289775803374597,
                    0.07913099117411775,
                    -0.03820256266096591,
                    0.0326401057309401,
                    -0.08466108083974908,
                    -0.013780519866985974,
                    0.06890310797294548,
                    -0.06021770304600942,
                    0.022470876438763374,
                    -0.025450631917286293,
                    0.05726023213667222,
                    -0.07185139404528157,
                    -0.0009566508760441469,
                    0.08177967773331704,
                    -0.047290145609991126,
                    0.03548055524741869,
                    -0.09356298745622693,
                    0.07039512531033765,
                    -0.05863674177293379,
                ],
                [
                    -0.09377100000000001,
                    -0.022892999999999997,
                    0.059798,
                    -0.06932400000000001,
                    0.013366999999999995,
                    -0.0345736,
                    0.0481174,
                    -0.08100460000000001,
                    0.001686399999999999,
                    0.0725644,
                    -0.0447446,
                    0.026133399999999994,
                    -0.0911756,
                    0.060883599999999996,
                    -0.0564254,
                    0.014452599999999993,
                    0.0971436,
                    -0.0319784,
                    0.0507126,
                    -0.0784094,
                    0.0042815999999999965,
                    0.0751596,
                    0.039032,
                    -0.09009,
                ],
                [
                    0.18447984796444766,
                    0.050549907796518516,
                    0.13205348766401642,
                    0.20737267038571103,
                    0.08239263920834507,
                    0.01620519959373799,
                    0.06471884504335294,
                    0.12875573768574444,
                    -0.01774713897776613,
                    0.05573876569387208,
                    -0.07863401296420879,
                    -0.0026232889532777606,
                    0.06092267946041859,
                    0.017402989377819478,
                    -0.12062206742099178,
                    -0.040908468838196946,
                    0.03919650936200455,
                    -0.056749837266459775,
                    0.03149338137717491,
                    0.12947421356280683,
                    0.01704480238596738,
                    -0.004805092364709424,
                    0.08392569579208306,
                    0.1893154138505638,
                ],
                [
                    0.03252799118593452,
                    -0.08478094183585033,
                    -0.013903065472896374,
                    0.06878787589818655,
                    0.020847611647330858,
                    -0.09646129909926573,
                    -0.025582942998698798,
                    0.05710839236668758,
                    -0.07201304759744981,
                    0.010678213523672975,
                    0.08155609683975408,
                    -0.03575286792687527,
                    -0.0836935462296861,
                    -0.012815449135374317,
                    0.06987588118782456,
                    -0.059246586624801416,
                    0.023443351534932915,
                    0.09432100008063003,
                    -0.02298824561907487,
                    0.0478897087714869,
                    0.011761920223817971,
                    0.08263993297709106,
                    -0.034669120714227085,
                    0.036208912986485986,
                ],
                [
                    -0.2073553931342036,
                    -0.07278781733621717,
                    -0.11072716348490406,
                    0.004201603394498865,
                    0.004519232764451153,
                    0.16954648322976668,
                    0.11994381147775385,
                    0.23027012356416166,
                    0.29128063277017113,
                    0.17695268726601812,
                    0.2371499475361844,
                    0.10849485133127822,
                    0.26394542841429347,
                    0.18309276223614845,
                    0.35845868963749017,
                    0.3251631622526473,
                    0.35117848808377,
                    0.31939136020818365,
                    0.19088342364754116,
                    0.31675185880090995,
                    0.49798300772021264,
                    0.9814340243442788,
                    1.511350973485526,
                    2.511766954793907,
                ],
                [
                    -0.0297368421092453,
                    0.041158460493318436,
                    -0.08794813442767176,
                    0.0759365508129182,
                    -0.05317766642809878,
                    0.029501194806701193,
                    -0.09962162604917742,
                    -0.01693738128712482,
                    0.05395468706520464,
                    -0.06335721108369906,
                    0.007501175739437978,
                    -0.028627284251466587,
                    0.042238777168887305,
                    -0.07507865221336169,
                    -0.004226046831466281,
                    0.07843453220836769,
                    -0.050708238296224666,
                    0.031953459257837415,
                    -0.09718745708073964,
                    0.06667791872887656,
                    -0.062448682798413285,
                    0.02021592719676636,
                    0.09104682579625827,
                    -0.02629246814799321,
                ],
                [
                    -0.07288524035334594,
                    0.010072897491075445,
                    0.08126350412408932,
                    0.04542262876814589,
                    -0.08265900688265507,
                    0.0016932353756438335,
                    0.07411745647210997,
                    -0.04212975267876837,
                    0.02948138074415119,
                    0.1009561950219722,
                    -0.015958264256311125,
                    -0.06343789047471418,
                    0.019598774620825883,
                    0.09092181269856064,
                    -0.02568458554089945,
                    0.045767911339126464,
                    -0.07181504959329367,
                    -0.002223365660369155,
                    0.07915304391230038,
                    0.029993375870027282,
                    -0.08803197169081955,
                    -0.016949443103326044,
                    0.06651051605863317,
                    -0.06113016514838245,
                ],
                [
                    0.08219739115585713,
                    -0.04308849451280582,
                    -0.07776963295951832,
                    -0.009857790797145776,
                    0.08186280269574497,
                    -0.03172347200944647,
                    0.07091734986784443,
                    -0.03945167282146533,
                    0.0635682836839723,
                    0.14563773739992653,
                    0.1184513086275251,
                    0.19889467990871185,
                    0.09123369272619673,
                    0.1706257663282062,
                    0.06100676886280245,
                    0.13199448572618303,
                    0.20881239967140322,
                    0.06779768251437923,
                    0.019910709800178478,
                    0.0770959455370418,
                    0.13959382423807776,
                    -0.01403565872619397,
                    0.03218566151801578,
                    -0.14744355289974156,
                ],
                [
                    0.04571709499872495,
                    0.010160925086839613,
                    0.08087156413783447,
                    -0.03682612199789844,
                    0.03516963574308468,
                    -0.07916547572751853,
                    -0.003104755336139809,
                    0.08575544897423136,
                    -0.03742864743614274,
                    0.049174582769141786,
                    0.0038838492249906632,
                    0.08898628280945192,
                    -0.03813692099645037,
                    0.04636464858752955,
                    0.1180344592374395,
                    -0.001030289766113078,
                    0.06451851777947083,
                    -0.05933540410359508,
                    0.08847163499952204,
                    -0.03139670880960187,
                    0.03782262246468793,
                    -0.08004483854472333,
                    -0.009540258779623558,
                    0.07252362578133767,
                ],
            ],
            // One bias per hidden neuron (18).
            biases: [
                0.031452782348287814,
                0.013479599999999991,
                -0.017457599999999997,
                -0.059063216360768916,
                -0.0911448,
                -0.040900599999999995,
                -0.0718376,
                0.1073884705321795,
                -0.00932460799253341,
                0.025024273605055485,
                -0.019211999999999996,
                0.030858780354485995,
                -0.08110263308587991,
                -0.7012227963759011,
                0.043252418722894904,
                0.015426117922068616,
                0.06631715448522192,
                -0.051785310469542385,
            ],
        },
        // No extra hidden layers in the source network (L1 == 0).
        hidden_layers: [],
        output_layer: Layer {
            // 12 output rows x 18 hidden-neuron weights each.
            weights: [
                [
                    -0.05703802526749842,
                    -0.011203600000000003,
                    0.059674399999999996,
                    -0.023500462594586415,
                    0.0250564,
                    0.0959344,
                    -0.0213746,
                    -0.07339371830505204,
                    -0.11837684512070612,
                    0.08672360234492896,
                    -0.03305519999999999,
                    0.08630340871104916,
                    -0.07943438331041205,
                    0.18181386366828328,
                    0.07705193407733937,
                    0.0025142231890497736,
                    -0.18729989186287485,
                    0.012506093498565855,
                ],
                [
                    -0.031753477751746244,
                    0.02778400000000001,
                    -0.089525,
                    -0.014881949882189124,
                    -0.05477460000000001,
                    0.01610339999999999,
                    0.09879439999999999,
                    -0.0229873830932048,
                    -0.12375892515765144,
                    -0.07228331611646031,
                    0.005932399999999993,
                    0.10214470737273329,
                    0.04070916546711668,
                    0.1797685814280238,
                    -0.0015358924970939858,
                    0.04017474344742837,
                    -0.17859552865201775,
                    0.001387494730463138,
                ],
                [
                    0.01920305159520628,
                    -0.052046999999999996,
                    0.018830999999999997,
                    -0.07255468756357401,
                    -0.027600000000000003,
                    0.055091,
                    -0.074031,
                    0.005032633378341296,
                    -0.11773417946100548,
                    0.04732698560311076,
                    -0.0857118,
                    0.07985243089928101,
                    0.06790294068066849,
                    0.18060139305022435,
                    0.03720171071655352,
                    -0.1289642658627665,
                    -0.147651663833331,
                    -0.05844859933345004,
                ],
                [
                    -0.06129041097281596,
                    -0.024872399999999996,
                    0.05781860000000001,
                    -0.037580253780329365,
                    0.011387600000000008,
                    0.08226560000000001,
                    -0.0350434,
                    -0.09520926686694828,
                    -0.16909667140011037,
                    0.07362754625837097,
                    -0.04672420000000001,
                    0.08781944322424497,
                    -0.09308617617244912,
                    0.17946253561302594,
                    0.06490376837812918,
                    -0.011105567669666154,
                    -0.027287940133310113,
                    -0.016821907441529508,
                ],
                [
                    -0.0425567175462357,
                    0.014115199999999994,
                    0.08499319999999999,
                    -3.631702416102432e-5,
                    -0.0802566,
                    0.002434399999999992,
                    0.07331240000000001,
                    -0.05329641845373612,
                    -0.084577424668251,
                    -0.08746673893236924,
                    -0.0195496,
                    0.09152032936597139,
                    0.0270887075550665,
                    0.15762589559398046,
                    -0.014612380762211809,
                    0.026138525816254244,
                    -0.13763223866296204,
                    0.0156790597718171,
                ],
                [
                    -0.013236277291470188,
                    -0.077529,
                    0.005162,
                    0.08081857290711349,
                    -0.041269,
                    0.029608999999999996,
                    -0.0877,
                    -0.004133510511923142,
                    -0.03126421193487638,
                    0.03415501442447006,
                    -0.0993806,
                    0.014125518412704485,
                    0.054237166970245804,
                    0.14722748571902225,
                    0.01269668863102259,
                    0.0741281503888547,
                    -0.12193731979074944,
                    -0.10066322751979492,
                ],
                [
                    0.09629643493485918,
                    -0.0385414,
                    0.03233660000000001,
                    -0.05940247570441866,
                    -0.014094399999999997,
                    0.06859660000000001,
                    0.020656000000000008,
                    -0.11846511460566495,
                    -0.04374438662791898,
                    0.05879828411095856,
                    -0.06039300000000001,
                    -0.0009725010625753716,
                    0.09325908995648328,
                    0.12946159405101365,
                    -0.06644072163542755,
                    -0.016071352896887967,
                    0.025737380357193397,
                    -0.036592269244843094,
                ],
                [
                    -0.0716605783887,
                    -0.0113668,
                    0.07132419999999999,
                    0.039227993696773315,
                    -0.0939254,
                    -0.0112344,
                    0.0596436,
                    -0.06888438064183894,
                    0.052653480176269,
                    0.09737368882999387,
                    -0.0332184,
                    -0.08109643223608066,
                    0.0015198134279518886,
                    0.10848216010067242,
                    -0.03921600011050003,
                    0.0370939521274188,
                    -0.0793951508696832,
                    0.026148745848090012,
                ],
                [
                    -0.05784469705460472,
                    -0.0911978,
                    -0.0203198,
                    0.04314522651781204,
                    -0.0667508,
                    0.015940200000000005,
                    0.0986312,
                    -0.05012790002432129,
                    0.1011235576277164,
                    0.005210271552491358,
                    0.08695040000000001,
                    -0.10734890890977591,
                    0.040411007643487894,
                    0.08667700659294068,
                    0.003490331585907222,
                    0.050552568192751106,
                    -0.08956296860469341,
                    0.08460347832138619,
                ],
                [
                    0.06130318143340829,
                    -0.0640232,
                    0.01866779999999999,
                    -0.10528751801100177,
                    -0.0277632,
                    0.05492779999999999,
                    0.00698700000000001,
                    0.05921636390269124,
                    0.07459058563894118,
                    0.04236483247814882,
                    -0.085875,
                    -0.10633127459407658,
                    0.06755202417287051,
                    0.07687365421248915,
                    -0.09155449352496135,
                    -0.010688660970600726,
                    0.06934762870717552,
                    -0.038809066494141495,
                ],
                [
                    0.08785580999116811,
                    -0.0250356,
                    0.04584239999999999,
                    0.006993285367626222,
                    0.0924056,
                    -0.0367164,
                    0.0459746,
                    -0.10762083218141824,
                    0.07104722556282012,
                    0.0697344793692867,
                    -0.0468874,
                    -0.16960909754507597,
                    -0.012267110436804741,
                    0.07859718684549076,
                    -0.05208907992725116,
                    0.018023234093684833,
                    0.11086063223529301,
                    0.0017069715622517166,
                ],
                [
                    0.0009267403284441151,
                    0.0951332,
                    -0.0339888,
                    0.0269955548943452,
                    -0.0804198,
                    0.0022712000000000066,
                    0.0731492,
                    -0.047185344900259885,
                    0.08665367383529304,
                    -0.009748430926091865,
                    0.061468600000000005,
                    -0.17642604794388536,
                    0.014857112794169012,
                    0.06902555854977288,
                    -0.025225915730265624,
                    0.06737291263559271,
                    0.08546101628595873,
                    0.10696135350502953,
                ],
            ],
            // One bias per output neuron (12).
            biases: [
                0.15564504222217423,
                0.1473202333557324,
                0.14150325747102618,
                0.1417972607008875,
                0.12915401979983357,
                0.11102123566726314,
                0.10943827460128827,
                0.11337626442372599,
                0.11272235937363072,
                0.08431237434673645,
                0.10892382608690006,
                0.10589741994315555,
            ],
        },
    });
    // Dump the widened network so it can be captured as the new checkpoint.
    println!("{:#?}", extended);
}
/// Widens a `Layer<I1, O1>` into a `Layer<I2, O2>` (I2 >= I1, O2 >= O1).
///
/// Parameters that existed in the source layer are copied verbatim; every
/// newly introduced weight/bias keeps its random initial value divided by
/// `EXTENTION_RANDOMNESS` (a no-op while that constant is 1.0).
fn extend_layer<const I1: usize, const I2: usize, const O1: usize, const O2: usize>(
    layer: Layer<I1, O1>,
) -> Layer<I2, O2> {
    // Begin with a fully randomised layer, then overlay the old parameters.
    let mut extended: Layer<I2, O2> = Layer::random();
    // Walk biases and weight rows in lockstep; row index == output neuron.
    for (row_idx, (bias, row)) in extended
        .biases
        .iter_mut()
        .zip(extended.weights.iter_mut())
        .enumerate()
    {
        let existing_row = row_idx < O1;
        if existing_row {
            *bias = layer.biases[row_idx];
        } else {
            *bias /= EXTENTION_RANDOMNESS;
        }
        for (col_idx, weight) in row.iter_mut().enumerate() {
            if existing_row && col_idx < I1 {
                *weight = layer.weights[row_idx][col_idx];
            } else {
                *weight /= EXTENTION_RANDOMNESS;
            }
        }
    }
    extended
}
fn extend_network<
const I1: usize,
const I2: usize,
const O1: usize,
const O2: usize,
const H1: usize,
const H2: usize,
const L1: usize,
const L2: usize,
>(
network: NeuralNetwork<I1, O1, H1, L1>,
) -> NeuralNetwork<I2, O2, H2, L2> {
let old_hidden_layers = network
.hidden_layers
.map(|layer| extend_layer::<H1, H2, H1, H2>(layer));
let mut new_hidden_layers = [Layer::<H2, H2>::random(); L2];
for i in 0..L2 {
if i < L1 {
new_hidden_layers[i] = old_hidden_layers[i];
} else {
for k in 0..H2 {
new_hidden_layers[i].biases[k] /= EXTENTION_RANDOMNESS;
for j in 0..H2 {
new_hidden_layers[i].weights[k][j] /= EXTENTION_RANDOMNESS;
}
}
for j in 0..H2 {
new_hidden_layers[i].weights[j][j] = 1.;
}
}
}
NeuralNetwork {
input_layer: extend_layer::<I1, I2, H1, H2>(network.input_layer),
output_layer: extend_layer::<H1, H2, O1, O2>(network.output_layer),
hidden_layers: new_hidden_layers,
}
}