dp_policy_eval.py

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#!/usr/bin/env python
# -*- coding: ascii -*-
from __future__ import print_function
from __future__ import unicode_literals
from future import standard_library
standard_library.install_aliases()
from builtins import str
from builtins import range
from builtins import object

    
def dp_policy_evaluation( policy, state_value, do_summ_print=True, fmt_V='%g',
                          max_iter=1000, err_delta=0.001, gamma=0.9):
    """
    ... GIVEN A POLICY TO EVALUATE  apply State-Value Policy Evaluation 
    
    Use Policy-Evaluation to find V(s), State-Value Function
    
    Terminates when delta < err_delta * VI_STOP_CRITERIA
    
    Assume that V(s), state_value, has been initialized prior to call.
    (Note tht the StateValues object has a reference to the Environment object)
    
    state_value WILL BE CHANGED... policy WILL NOT.
    
    This code takes the state_values all the way to their final values for this policy.
    More general policy evaluation only goes part-way before improving the policy.
    """
    
    loop_counter = 0
    all_done = False
       
    # value-iteration stopping criteria
    # if gamme==1.0 value iteration will never stop SO limit to gamma==0.999 stop criteria
    #  (VI terminates if delta < err_delta * VI_STOP_CRITERIA)
    #  (typically err_delta = 0.001)
    VI_STOP_CRITERIA = max((1.0-gamma) / gamma, (1.0-0.999)/0.999) 
    error_limit = err_delta * VI_STOP_CRITERIA
    
    # ==> Note: the reference to Environment object as "state_value.environment"
    Env = state_value.environment
    max_delta = 0.0
    
    while (loop_counter<max_iter) and (not all_done):
        loop_counter += 1
        all_done = True
        delta = 0.0 # used to calc largest change in state_value
        
        # policy evaluation 
        for s_hash in policy.iter_all_policy_states():
            
            calcd_v = 0.0
            for a_desc, a_prob in policy.iter_policy_ap_for_state( s_hash, incl_zero_prob=False):
                
                for sn_hash, t_prob, reward in \
                    Env.iter_next_state_prob_reward(s_hash, a_desc, incl_zero_prob=False):
                    
                    calcd_v += t_prob * a_prob * ( reward + gamma * state_value(sn_hash) )
            
            delta = max( delta, abs(calcd_v - state_value(s_hash)) )
            if delta > error_limit:
                all_done = False
                max_delta = max(max_delta, delta) # returned to caller
            
            state_value[s_hash] = calcd_v
    
    if do_summ_print:
        s = ''
        if loop_counter >= max_iter:
            s = '   (NOTE: STOPPED ON MAX-ITERATIONS)'

        print( 'Exited Policy Evaluation', s )
        print( '   iterations     =', loop_counter, ' (limit=%i)'%max_iter )
        print( '   measured delta =', delta )
        print( '   gamma          =', gamma )
        print( '   err_delta      =', err_delta )
        print( '   error limit    =',error_limit )
        print( '   STOP CRITERIA  =',VI_STOP_CRITERIA)
    
        state_value.summ_print( fmt_V=fmt_V )

    return max_delta

if __name__ == "__main__": # pragma: no cover
    
    from introrl.policy import Policy
    from introrl.mdp_data.simple_grid_world import get_gridworld
    from introrl.state_values import StateValues
    
    gridworld = get_gridworld()
    pi = Policy(  environment=gridworld  )
    pi.set_policy_from_piD( gridworld.get_default_policy_desc_dict() )
    
    sv = StateValues( gridworld )
    #sv.init_Vs_to_zero() # done when StateValues is created.
    
    dp_policy_evaluation( pi, sv, max_iter=1000, err_delta=0.001, gamma=0.9)