#ifndef QLEARNING_H #define QLEARNING_H #include #include #include "diceTower.h" using namespace std; #define MAX(a,b) ((a) > (b) ? (a) : (b) ) #define MIN(a,b) ((a) < (b) ? (a) : (b) ) // CONSTS const double gamma = 0.8; // discount on reward const int exploreEps = 70; // exploration constant const double stepReward = 11.0; // rewards const double lavaReward = -10.0; const double cakeReward = 10.0; const int actioneffect[4][2] = {{-1,0},{0,1},{1,0},{0,-1}}; const char actionchar[4] = {'^', '>', 'v', '<'}; // GLOBALS diceTower roll; double QTable[4][3][4]; int count[4][3][4]; //---------------------------------------------------- void initGlobals() { for(int k=0; k < 4; k++) for (int j=0; j < 3; j++) for ( int l=0; l < 4; l++){ QTable[k][j][l] = 0; count[k][j][l] = 0; } } //---------------------------------------------------- bool isWall(int r, int c) { if (r < 0 || r > 3) return true; if (c < 0 || c > 2) return true; if (r == 2 && c == 1) return true; return false; } //---------------------------------------------------- int getPolicy(int r, int c) { double v = -99999.0; int a = 0; for(int k=0; k < 4; k++){ //cerr << QTable[r][c][k] << endl; if ( QTable[r][c][k] > v ) { v = QTable[r][c][k]; a = k; } } return a; } //---------------------------------------------------- double getValue(int r, int c) { double v = -99999.0; for(int k=0; k < 4; k++){ //cerr << QTable[r][c][k] << endl; v = MAX( v, QTable[r][c][k] ); } return v; } //---------------------------------------------------- int getSCount(int r, int c) { int x = 0; for(int k=0; k < 4; k++){ //cerr << count[r][c][k] << endl; x += count[r][c][k]; } return x; } //---------------------------------------------------- void printBoards() { for(int k=0; k < 4; k++){ // for row; for (int j=0; j < 3; j++){ // for column SCount if ( isWall(k,j) ) cout << setw(6) << "X"; else cout << setw(6) << getSCount(k,j); cout << "|"; } cout << " "; for (int j=0; j < 3; j++){ // for column Value if ( isWall(k,j) ) cout << setw(6) << "X"; else cout << setw(6) << setprecision(4) << getValue(k,j); cout << "|"; } cout << " "; for (int j=0; j < 3; j++){ // for column Policy if ( isWall(k,j) ) cout << setw(6) << "X"; else cout << setw(6) << actionchar[ getPolicy(k,j) ]; cout << "|"; } cout << endl; } } //---------------------------------------------------- bool exec(int sr, int sc, int a, double &r, int &srp, int &scp) { // Case: at Cake if (sr == 0 && sc == 2){ srp = 3; scp = 0; r = cakeReward; return true; } // Case : at Lava if (sr == 0 && sc == 1){ srp = 3; scp = 0; r = lavaReward; return true; } // Case : other square // will action work as intended? 80% chance it will not. // [1..10] drift; [11..89] normal; [90..100]drift the other way int p = roll.d(100); if ( p <= 10 ) { a = (a + 4 -1) % 4; /*cerr << "skid!" << endl;*/ } if ( p >= 90 ) { a = (a + 4 +1) % 4; /*cerr << "skid!" << endl;*/ } // compute consequence srp = sr + actioneffect[a][0]; scp = sc + actioneffect[a][1]; // If move is to an invalid place, stay in place. if( isWall( srp, scp ) ){ srp = sr; scp = sc; } r = stepReward; return false; } //---------------------------------------------------- int selectAction( int r, int c ) { int a = getPolicy(r, c); int p = roll.d(100); if ( p <= exploreEps ) return a; a = (a + roll.d(3)) %4; return a; } //---------------------------------------------------- double getTDalpha(int sr, int sc, int a) { double w = 5.0; return (w + 1)/(w + count[sr][sc][a] ); } void doEpisode() { // Init start state int sr = 3; int sc = 0; bool done = false; while (!done) { int a = selectAction(sr,sc); double rew; int srp, scp; done = exec(sr, sc, a, rew, srp, scp); count[sr][sc][a]++; QTable[sr][sc][a] += getTDalpha(sr,sc,a) * ( rew + gamma*getValue(srp, scp) - QTable[sr][sc][a]); sr = srp; sc = scp; } } #endif