sample_gridworld.py

from introrl.environments.env_baseline import EnvBaseline
from introrl.reward import Reward

# define layout to create output displays
row_1 = [ (0,0), (0,1),   (0,2), 'Goal' ]
row_2 = [ (1,0),'"Wall"', (1,2), 'Pit' ]
row_3 = [ 'Start', (2,1),   (2,2), (2,3) ]
s_hash_rowL=[row_1, row_2, row_3]

# add layout row and column markings (if any)
row_tickL=[ 0, 1, 2]
col_tickL=[ 0, 1, 2, 3]
x_axis_label='cols'
y_axis_label='rows'

# one way to define actions is an explicit dict of actions.
# (can also simply provide logic within a function to define actions)
actionD = {(0, 0): ('D', 'R'),
           (0, 1): ('L', 'R'),
           (0, 2): ('L', 'D', 'R'),
           (1, 0): ('U', 'D'),
           (1, 2): ('U', 'D', 'R'),
           'Start': ('U', 'R'),
           (2, 1): ('L', 'R'),
           (2, 2): ('L', 'R', 'U'),
           (2, 3): ('L', 'U')  }

# define rewards
rewardD = {'Goal': 1, 'Pit': -1}
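# (every other transition falls back to a reward of 0.0, via the
#  rewardD.get( sn_hash, 0.0 ) lookup inside get_gridworld below)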

def get_next_state( s_hash, a_desc ):
    """use layout definition to get next state"""
    
    if s_hash == 'Start':
        s_hash = (2,0)
    row,col = s_hash # all non-terminal s_hash are (row, col)
    if a_desc == 'U':
        row -= 1
    elif a_desc == 'D':
        row += 1
    elif a_desc == 'R':
        col += 1
    elif a_desc == 'L':
        col -= 1
    # no limit checking done... assume only legal moves are submitted
    return s_hash_rowL[row][col]

def get_gridworld():
    gridworld = EnvBaseline( name='Sample Grid World', s_hash_rowL=s_hash_rowL,
                             row_tickL=row_tickL, col_tickL=col_tickL,
                             x_axis_label=x_axis_label, y_axis_label=y_axis_label,
                             colorD={'Goal':'g', 'Pit':'r', 'Start':'b'},
                             basic_color='skyblue')
                             
    gridworld.set_info( 'Sample Grid World showing basic MDP creation.' )

    # add actions from each state 
    #   (note: a_prob will be normalized within add_action_dict)
    gridworld.add_action_dict( actionD )
    
    # for each action, define the next state and transition probability 
    # (here we use the layout definition to aid the logic)
    for s_hash, aL in actionD.items():
        for a_desc in aL:
            sn_hash = get_next_state( s_hash, a_desc )
            reward = rewardD.get( sn_hash, 0.0 )
            
            # for deterministic MDP, use t_prob=1.0
            gridworld.add_transition( s_hash, a_desc, sn_hash, 
                                      t_prob=1.0, reward_obj=reward)
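            # (For a stochastic MDP, one would presumably call add_transition
            #  once per possible next state, with t_prob values summing to 1.0;
            #  any required normalization is handled when
            #  define_env_states_actions() is called below.)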
    
    # after the "add" commands, send all states and actions to environment
    # (any required normalization is done here as well.)
    gridworld.define_env_states_actions()  

    # If there is a start state, define it here.
    gridworld.start_state_hash = 'Start'
    
    # If a limited number of start states are desired, define them here.
    gridworld.define_limited_start_state_list( [(2,0), (2,2)] )

    # if a default policy is desired, define it as a dict.
    gridworld.default_policyD = {(0, 0):'R',(1, 0):'U',(0, 1):'R',(0, 2):'R',
                                 (1, 2):'U','Start':'U',(2, 2):'U',(2, 1):'R',
                                 (2, 3):'L'}

    return gridworld
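
# A minimal usage sketch (assuming this file is saved as sample_gridworld.py,
# as titled above): another script can build the same environment with
#
#   from sample_gridworld import get_gridworld
#   gridworld = get_gridworld()
#
# and then call display helpers such as layout_print(), as in the demo below.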

if __name__ == "__main__": # pragma: no cover
    
    gridworld = get_gridworld()
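    # Quick sanity checks of the layout-driven helper defined above:
    # stepping right from 'Start' lands on (2,1), and stepping right
    # from (0,2) reaches the 'Goal' terminal state.
    assert get_next_state('Start', 'R') == (2, 1)
    assert get_next_state((0, 2), 'R') == 'Goal'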
    #gridworld.summ_print()
    gridworld.layout_print(vname='reward', fmt='', show_env_states=True, none_str='*')

    gridworld.layout.s_hash_diagram( save_name='sample_diagram', 
                                     none_str='*', do_show=True,
                                     pad=0.05, scale=0.75, h_over_w=1.0)

    gridworld.save_to_pickle_file( fname=None )