dp_policy_improve.py

#!/usr/bin/env python
# -*- coding: ascii -*-
from __future__ import print_function
from __future__ import unicode_literals
from future import standard_library
standard_library.install_aliases()
from builtins import str
from builtins import range
from builtins import object


from introrl.utils.functions import argmax_vmax_dict
    
def dp_policy_improvement( policy, state_value, gamma=0.9,
                           do_summ_print=True, max_iter=1000):
    """
    ... GIVEN STATE-VALUES ...  apply State-Value Policy Improvement
    
    Use Policy-Improvement to find best policy for current V(s) values
    
    Terminates when policy is stable.
    
    Assume that V(s), state_value, has been initialized prior to call.
    (Note that the StateValues object has a reference to the Environment object)
    
    policy WILL BE CHANGED... state_value WILL NOT.
    """
    
    loop_counter = 0
    is_stable = False
    made_changes = False
           
    # Note: the Environment object is available as "state_value.environment"
    Env = state_value.environment
    
    while (loop_counter<max_iter) and (not is_stable):
        loop_counter += 1
        is_stable = True
        
        # policy improvement
        for s_hash in policy.iter_all_policy_states():
            old_action = policy.get_single_action( s_hash )
            
            VsD = {} # will hold: key=a_desc, value=action-value of a_desc from s_hash (summed over its transitions)
            for a_desc, a_prob in policy.iter_policy_ap_for_state( s_hash, incl_zero_prob=True):
                VsD[a_desc] = 0.0
                for sn_hash, t_prob, reward in \
                    Env.iter_next_state_prob_reward(s_hash, a_desc, incl_zero_prob=False):
                    
                    # a_prob is intentionally not used: VsD[a_desc] is the per-action
                    # value Q(s,a), not a policy-weighted sum over actions.
                    #VsD[a_desc] += t_prob * a_prob * ( reward + gamma * state_value(sn_hash) )
                    VsD[a_desc] += t_prob * ( reward + gamma * state_value(sn_hash) )

            # use pick_random_best=False to avoid subtle non-termination bug (see page 82)
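            # (with random tie-breaking, equally-valued best actions can keep replacing one
            #  another, so is_stable would never latch True and the loop would not terminate)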
            best_a_desc, best_a_val = argmax_vmax_dict( VsD, pick_random_best=False )
            
            if best_a_desc != old_action:
                is_stable = False
                made_changes = True # returned to caller
            
            policy.set_sole_action( s_hash, best_a_desc)
    
    if do_summ_print:
        s = ''
        if loop_counter >= max_iter:
            s = '   (NOTE: STOPPED ON MAX-ITERATIONS)'
        print( '=========================' + '='*len(s) )
        print( 'Exited Policy Improvement', s )
        print( '   iterations    =', loop_counter, ' (limit=%i)'%max_iter )
        print( '   gamma         =', gamma )
        print( '=========================' + '='*len(s) )
    
        state_value.summ_print()
        
    return made_changes
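
# A minimal sketch (not part of this module) of full policy iteration built from the two
# routines exercised in the demo below: alternate evaluation and improvement until
# improvement reports no policy changes.  Parameter values are illustrative only.
#
#     made_changes = True
#     while made_changes:
#         dp_policy_evaluation( pi, sv, max_iter=1000, err_delta=0.001, gamma=0.9,
#                               do_summ_print=False )
#         made_changes = dp_policy_improvement( pi, sv, gamma=0.9, do_summ_print=False )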

if __name__ == "__main__": # pragma: no cover
    
    from introrl.policy import Policy
    from introrl.state_values import StateValues
    from introrl.dp_funcs.dp_policy_eval import dp_policy_evaluation
    from introrl.mdp_data.simple_grid_world import get_gridworld
    
    gridworld = get_gridworld()
    pi = Policy(  environment=gridworld  )
    pi.set_policy_from_piD( gridworld.get_default_policy_desc_dict() )
    
    print('-'*55)
    sv = StateValues( gridworld )
    sv.init_Vs_to_zero()
    
    dp_policy_evaluation( pi, sv, max_iter=1000, err_delta=0.001, gamma=0.9, do_summ_print=False)

    print('-'*55)
    pi_2 = Policy(  environment=gridworld  )
    pi_2.intialize_policy_to_random( env=gridworld )
    print('-------- Random Policy Prior to Improvement ----------')
    pi_2.summ_print( verbosity=0 )
    
    # sv should be optimum for the "pi" policy... see if "pi_2" is changed to "pi"
    dp_policy_improvement( pi_2, sv, gamma=0.9, do_summ_print=True, max_iter=1000)
    
    print('-------- Random Policy AFTER Improvement ----------')
    pi_2.summ_print(  environment=gridworld, verbosity=0, show_env_states=False  )
    print('-------- Default gridworld policy ----------')
    pi.summ_print(  environment=gridworld, verbosity=0, show_env_states=False  )