Skip to content

Commit f8663ea

Browse files
committed
fixes acktr_cont issues
1 parent 699919f commit f8663ea

File tree

3 files changed

8 additions and 6 deletions (+8 / -6 lines changed)

3 files changed

8 additions and 6 deletions (+8 / -6 lines changed)

baselines/acktr/acktr_cont.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ def rollout(env, policy, max_pathlength, animate=False, obfilter=None):
4646
"action_dist": np.array(ac_dists), "logp" : np.array(logps)}
4747

4848
def learn(env, policy, vf, gamma, lam, timesteps_per_batch, num_timesteps,
49-
animate=False, callback=None, optimizer="adam", desired_kl=0.002):
49+
animate=False, callback=None, desired_kl=0.002):
5050

5151
obfilter = ZFilter(env.observation_space.shape)
5252

@@ -117,14 +117,16 @@ def learn(env, policy, vf, gamma, lam, timesteps_per_batch, num_timesteps,
117117
# Policy update
118118
do_update(ob_no, action_na, standardized_adv_n)
119119

120+
min_stepsize = np.float32(1e-8)
121+
max_stepsize = np.float32(1e0)
120122
# Adjust stepsize
121123
kl = policy.compute_kl(ob_no, oldac_dist)
122124
if kl > desired_kl * 2:
123125
logger.log("kl too high")
124-
U.eval(tf.assign(stepsize, stepsize / 1.5))
126+
U.eval(tf.assign(stepsize, tf.maximum(min_stepsize, stepsize / 1.5)))
125127
elif kl < desired_kl / 2:
126128
logger.log("kl too low")
127-
U.eval(tf.assign(stepsize, stepsize * 1.5))
129+
U.eval(tf.assign(stepsize, tf.minimum(max_stepsize, stepsize * 1.5)))
128130
else:
129131
logger.log("kl just right!")
130132

baselines/acktr/run_mujoco.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,4 +39,4 @@ def train(env_id, num_timesteps, seed):
3939
parser.add_argument('--seed', help='RNG seed', type=int, default=0)
4040
parser.add_argument('--env', help='environment ID', type=str, default="Reacher-v1")
4141
args = parser.parse_args()
42-
train(args.env_id, num_timesteps=1e6, seed=args.seed)
42+
train(args.env, num_timesteps=1e6, seed=args.seed)

baselines/acktr/value_functions.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ def __init__(self, ob_dim, ac_dim): #pylint: disable=W0613
1313
wd_dict = {}
1414
h1 = tf.nn.elu(dense(X, 64, "h1", weight_init=U.normc_initializer(1.0), bias_init=0, weight_loss_dict=wd_dict))
1515
h2 = tf.nn.elu(dense(h1, 64, "h2", weight_init=U.normc_initializer(1.0), bias_init=0, weight_loss_dict=wd_dict))
16-
vpred_n = dense(h2, 1, "hfinal", weight_init=U.normc_initializer(1.0), bias_init=0, weight_loss_dict=wd_dict)[:,0]
16+
vpred_n = dense(h2, 1, "hfinal", weight_init=None, bias_init=0, weight_loss_dict=wd_dict)[:,0]
1717
sample_vpred_n = vpred_n + tf.random_normal(tf.shape(vpred_n))
1818
wd_loss = tf.get_collection("vf_losses", None)
1919
loss = U.mean(tf.square(vpred_n - vtarg_n)) + tf.add_n(wd_loss)
@@ -22,7 +22,7 @@ def __init__(self, ob_dim, ac_dim): #pylint: disable=W0613
2222
optim = kfac.KfacOptimizer(learning_rate=0.001, cold_lr=0.001*(1-0.9), momentum=0.9, \
2323
clip_kl=0.3, epsilon=0.1, stats_decay=0.95, \
2424
async=1, kfac_update=2, cold_iter=50, \
25-
weight_decay_dict=wd_dict, max_grad_norm=None)
25+
weight_decay_dict=wd_dict, max_grad_norm=1.0)
2626
vf_var_list = []
2727
for var in tf.trainable_variables():
2828
if "vf" in var.name:

0 commit comments

Comments
 (0)