dp_value_iter.py

#!/usr/bin/env python
# -*- coding: ascii -*-
from __future__ import print_function
from __future__ import unicode_literals
from future import standard_library
standard_library.install_aliases()
from builtins import str
from builtins import range
from builtins import object

from introrl.policy import Policy
from introrl.state_values import StateValues
from introrl.utils.functions import argmax_vmax_dict, multi_argmax_vmax_dict
    
def dp_value_iteration( environment, allow_multi_actions=False,
                        do_summ_print=True, fmt_V='%g', fmt_R='%g',
                        max_iter=1000, err_delta=0.001, gamma=0.9,
                        iteration_prints=0):
    """
    ... GIVEN AN ENVIRONMENT ... 
    apply Value Iteration to find the OPTIMAL POLICY
    
    Returns: policy and state_value objects
    
    Terminates when delta < err_delta * VI_STOP_CRITERIA
    
    CREATES BOTH policy AND state_value OBJECTS.
    
    If allow_multi_actions is True, policy will include all actions 
    within err_delta of best action.
    """
    
    # create Policy and StateValues objects
    policy = Policy( environment=environment )
    policy.intialize_policy_to_random(env=environment)
        
    state_value = StateValues( environment )
    state_value.init_Vs_to_zero() # Terminal states need to be 0.0
    #state_value.summ_print()
    
    # set counter and flag
    loop_counter = 0
    all_done = False
       
    # value-iteration stopping criteria
    # if gamma==1.0 value iteration will never stop, SO limit the stop criteria to gamma==0.999
    #  (VI terminates if delta < err_delta * VI_STOP_CRITERIA)
    #  (typically err_delta = 0.001)
    
    VI_STOP_CRITERIA = max((1.0-gamma) / gamma, (1.0-0.999)/0.999) 
    error_limit = err_delta * VI_STOP_CRITERIA
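    # With the default inputs (gamma=0.9, err_delta=0.001) this gives:
    #   VI_STOP_CRITERIA = max( 0.1/0.9, 0.001/0.999 ) = 0.1111...
    #   error_limit      = 0.001 * 0.1111... = 1.111e-4 (approx)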
        
    while (loop_counter<max_iter) and (not all_done):
        loop_counter += 1
        all_done = True
        delta = 0.0 # used to calc largest change in state_value
        
        for s_hash in policy.iter_all_policy_states():
            VsD = {} # will hold: key=a_desc, value=backed-up action value summed over all transitions of a_desc from s_hash
            
            # MUST include currently zero prob actions
            for a_desc, a_prob in policy.iter_policy_ap_for_state( s_hash, incl_zero_prob=True):
                calcd_v = 0.0
                
                for sn_hash, t_prob, reward in \
                    environment.iter_next_state_prob_reward(s_hash, a_desc, incl_zero_prob=False):
                    
                    calcd_v += t_prob * ( reward + gamma * state_value(sn_hash) )
            
                VsD[a_desc] = calcd_v
            
            best_a_desc, best_a_val = argmax_vmax_dict( VsD )
            delta = max( delta, abs(best_a_val - state_value(s_hash)) )
            state_value[s_hash] = best_a_val
            
        if delta > error_limit:
            all_done = False
        
        if iteration_prints and (loop_counter % iteration_prints == 0):
            print('Loop:%6i'%loop_counter,'  delta=%g'%delta)
        
    # Now that State-Values have been determined, set policy
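    # Greedy policy extraction: for every state pick the action(s) that maximize
    #   sum over (s',r) of p(s',r|s,a) * ( reward + gamma*V(s') )
    # using the state values produced by the sweeps above.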
    for s_hash in policy.iter_all_policy_states():
        VsD = {} # will hold: key=a_desc, value=backed-up action value summed over all transitions of a_desc from s_hash
        
        # MUST include zero prob actions
        for a_desc, a_prob in policy.iter_policy_ap_for_state( s_hash, incl_zero_prob=True):
            calcd_v = 0.0
            
            for sn_hash, t_prob, reward in \
                environment.iter_next_state_prob_reward(s_hash, a_desc, incl_zero_prob=False):
                
                calcd_v += t_prob * ( reward + gamma * state_value(sn_hash) )
        
            VsD[a_desc] = calcd_v
        
        if allow_multi_actions:
            best_a_list, best_a_val = multi_argmax_vmax_dict( VsD, err_delta=err_delta )
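            # multi_argmax_vmax_dict returns every action whose backed-up value is
            # within err_delta of the best, so tied actions share probability equally below.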
            
            policy.set_sole_action( s_hash, best_a_list[0]) # zero all other actions
            prob = 1.0 / len(best_a_list)
            for a_desc in best_a_list:
                policy.set_action_prob( s_hash, a_desc, prob=prob)
        else:
            best_a_desc, best_a_val = argmax_vmax_dict( VsD )
            policy.set_sole_action( s_hash, best_a_desc)
    
    
    if do_summ_print:
        s = ''
        if loop_counter >= max_iter:
            s = '   (NOTE: STOPPED ON MAX-ITERATIONS)'

        print( 'Exited Value Iteration', s )
        print( '   iterations     =', loop_counter, ' (limit=%i)'%max_iter )
        print( '   measured delta =', delta )
        print( '   gamma          =', gamma )
        print( '   err_delta      =', err_delta )
        print( '   error limit    =',error_limit )
        print( '   STOP CRITERIA  =',VI_STOP_CRITERIA)
    
        state_value.summ_print( fmt_V=fmt_V )
        policy.summ_print( environment=environment, verbosity=0, show_env_states=False )
        
        environment.layout_print( vname='reward', fmt=fmt_R, show_env_states=False, none_str='*')
        
    return policy, state_value

if __name__ == "__main__": # pragma: no cover
    
    from introrl.mdp_data.simple_grid_world import get_gridworld
    
    gridworld = get_gridworld()
    
    policy, state_value = dp_value_iteration( gridworld, do_summ_print=True,
                                              max_iter=1000, err_delta=0.001, 
                                              gamma=0.9)