JAX Integration (Custom Training)#
Use this guide if you are not using the provided PPO implementation and want to integrate the viewer into your own JAX training loop with jax.experimental.io_callback.
Setup#
Start the viewer server (see examples or create one with WebViewer).
Wrap your environment and create a sender:
from braxviewer.BraxSender import BraxSender
from braxviewer.wrapper import ViewerWrapper
sender = BraxSender(host="127.0.0.1", port=8000, xml=xml_string, num_envs=num_envs)
sender.start()
env = ViewerWrapper(env=env, sender=sender)
Core concepts#
env.render_fn(state): Python callback that sends a frame. Safe to call via io_callback; handles a single or batched State.
env.should_render: device boolean (jax.Array) reflecting the server toggle. Use it to guard rendering in JIT code (no recompilation).
In your JAX loop, wrap the callback with lax.cond(should_render, ...) to avoid side effects when disabled.
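The guard described above can be tried in isolation. The following is a minimal sketch, not the viewer's own code: `fake_render`, `frames`, and `trace_count` are hypothetical stand-ins for `env.render_fn` and its side effects. It assumes only that the flag is passed into the jitted function as a boolean array, so flipping its value does not trigger a retrace.

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

frames = []       # stand-in for frames sent to the viewer (hypothetical)
trace_count = 0   # incremented only while JAX traces `step`


def fake_render(x):
    # Hypothetical host-side callback; the real one would be env.render_fn.
    frames.append(float(x))


@jax.jit
def step(x, should_render):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not per call
    x = x + 1.0
    # Both branches return None, so their output structures match.
    jax.lax.cond(
        should_render,
        lambda v: io_callback(fake_render, None, v),
        lambda v: None,
        x,
    )
    return x


x = jnp.zeros((), jnp.float32)
x = step(x, jnp.array(False))  # toggle off: callback skipped
x = step(x, jnp.array(True))   # toggle on: one frame recorded
x = step(x, jnp.array(False))  # toggle off again
jax.effects_barrier()          # wait for pending callbacks to finish
```

Because the flag is a traced argument rather than a Python `bool`, all three calls reuse the same compiled function, and only the middle call reaches the callback.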
Example: render hook in a lax.scan loop#
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

def train_step(env, policy, state, key, num_steps,
               render_fn=None, should_render=jnp.array(False, jnp.bool_)):
    def body(carry, _):
        st, k = carry
        k, k2 = jax.random.split(k)
        action, _ = policy(st.obs, k)
        nst = env.step(st, action)

        def do_render(s):
            # render_fn is a Python value, so this check is resolved at trace time.
            if render_fn is not None:
                io_callback(render_fn, None, s)

        # Both branches return None, so lax.cond sees matching output structures.
        jax.lax.cond(should_render, do_render, lambda s: None, nst)
        return (nst, k2), None

    (final_state, _), _ = jax.lax.scan(body, (state, key), None, length=num_steps)
    return final_state
Passing the render signals#
should_render = env.should_render # jax.Array(bool)
render_fn = env.render_fn # Python callable
state = train_step(env, policy, state, key, num_steps, render_fn, should_render)
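To show how the two signals travel from the wrapper into a jitted rollout across several calls, here is a small sketch with a stub environment. `StubEnv` and `rollout` are hypothetical and are not part of braxviewer. The sketch assumes only what this guide states: `should_render` is a device boolean and `render_fn` is a Python callable.

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback


class StubEnv:
    """Hypothetical stand-in for the wrapped environment (not the braxviewer API)."""

    def __init__(self):
        self.sent = []                          # frames "sent" to the viewer
        self.should_render = jnp.array(False)   # device boolean, as in the guide

    def render_fn(self, state):
        self.sent.append(int(state))            # host-side side effect

    def step(self, state, action):
        return state + action


def rollout(env, state, num_steps, render_fn, should_render):
    def body(st, _):
        nst = env.step(st, jnp.int32(1))
        jax.lax.cond(
            should_render,
            lambda s: io_callback(render_fn, None, s),
            lambda s: None,
            nst,
        )
        return nst, None

    final, _ = jax.lax.scan(body, state, None, length=num_steps)
    return final


env = StubEnv()
# The flag is a traced argument; env and render_fn are closed over as Python values.
run = jax.jit(lambda s, flag: rollout(env, s, 3, env.render_fn, flag))

state = jnp.int32(0)
state = run(state, env.should_render)   # toggle off: nothing is sent
env.should_render = jnp.array(True)     # simulate the server toggle flipping
state = run(state, env.should_render)   # toggle on: one frame per step
jax.effects_barrier()                   # wait for pending callbacks to finish
```

Reading `env.should_render` again before each call is what lets the outer loop pick up a changed toggle. The callback here is unordered, so the order in which frames arrive is not guaranteed in general.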