I am new to stable-baselines3 and am trying to get a toy graph neural network problem to work. I previously had a bit-flipping example using an array. The problem is this: given a list of 10 random bits and an operation which flips a single bit, find a sequence of flips that sets them all to 1. Clearly you can do this by just flipping the bits that are currently 0, but the system has to learn this.
I would like to do the same thing where the input is a simple linear graph with node weights instead of an array. I am not sure how to do this. The following code snippet makes a linear graph with 10 nodes, adds a weight to each node and converts it to a dgl graph:
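For reference, the behaviour the agent is expected to discover can be written down directly. This is only a sketch in plain Python, and the helper name flip_zeros is illustrative, not part of any library:

```python
import random


def flip_zeros(bits):
    """Return the indices to flip so that every bit ends up as 1."""
    bits = list(bits)  # work on a copy so the caller's list is untouched
    actions = []
    for i, bit in enumerate(bits):
        if bit == 0:
            bits[i] ^= 1  # same flip operation the environment applies
            actions.append(i)
    assert all(bits)
    return actions


bits = random.choices([0, 1], k=10)
print(bits, "->", flip_zeros(bits))
```

An optimal policy needs exactly as many steps as there are zeros in the initial list.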
import networkx as nx
import random
import dgl
# Create edges to add
edges = []
N = 10
for i in range(N-1):
    edges.append((i, i+1))
# Create graph and convert it into a dgl graph
G = nx.Graph()
G.add_edges_from(edges)
for i in range(len(G.nodes)):
    G.nodes[i]['weight'] = random.choice([0,1])
dgl_graph = dgl.from_networkx(G, node_attrs=["weight"])
When I was using a linear array for the bit flipping example my environment was this:
import random
import numpy as np
import gym
from gym import spaces
class BitFlipEnv(gym.Env):
    def __init__(self, array_length=10):
        super(BitFlipEnv, self).__init__()
        # Length of the bit array
        self.array_length = array_length
        # Initialize the array of bits to be random
        self.agent_pos = random.choices([0,1], k=array_length)
        # Define action and observation space
        # They must be gym.spaces objects
        # One discrete action per bit: action i flips bit i
        self.action_space = spaces.Discrete(array_length)
        # The observation is the current array of bits
        self.observation_space = spaces.Box(low=0, high=1,
                                            shape=(array_length,), dtype=np.uint8)
    def reset(self):
        # Initialize the array to have random values
        self.time = 0
        self.agent_pos = random.choices([0,1], k=self.array_length)
        return np.array(self.agent_pos, dtype=np.uint8)
    def step(self, action):
        self.time += 1
        if not 0 <= action < self.array_length:
            raise ValueError("Received invalid action={} which is not part of the action space".format(action))
        self.agent_pos[action] ^= 1 # flip the bit
        # Reward setting a bit to 1, penalize setting it back to 0
        if self.agent_pos[action] == 1:
            reward = 1
        else:
            reward = -1
        # The episode ends once every bit is 1
        done = all(self.agent_pos)
        info = {}
        return np.array(self.agent_pos, dtype=np.uint8), reward, done, info
    def render(self, mode='console'):
        pass
    def close(self):
        pass
The last few lines to complete the code in the array version are simply:
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
env = make_vec_env(lambda: BitFlipEnv(array_length=50), n_envs=12)
# Train the agent
model = PPO('MlpPolicy', env, verbose=1).learn(500000)
I can't use this array-based setup any more for the graph, so what is the right way to get stable-baselines3 to interface with my dgl graph for this toy problem?
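One direction that might work, sketched under assumptions rather than tested against stable-baselines3: because the linear graph has a fixed number of nodes and edges, it can be described by two fixed-shape arrays (the node weights and an edge index). Arrays of that kind are what a gym.spaces.Dict of Box spaces can carry as an observation, and the dgl graph could then be rebuilt from them inside a custom features extractor. The helper name path_graph_arrays and the keys node_weight and edge_index are hypothetical, and the sketch uses only NumPy:

```python
import numpy as np


def path_graph_arrays(weights):
    """Describe a linear graph as fixed-shape arrays (hypothetical helper)."""
    node_weight = np.asarray(weights, dtype=np.uint8)
    n = node_weight.shape[0]
    src = np.arange(n - 1, dtype=np.int64)  # edge i goes from node i ...
    dst = src + 1                           # ... to node i + 1
    edge_index = np.stack([src, dst])       # shape (2, n - 1)
    return {"node_weight": node_weight, "edge_index": edge_index}


obs = path_graph_arrays([0, 1, 1, 0, 1])
print(obs["node_weight"].shape, obs["edge_index"].shape)
```

Flipping a bit would then only change node_weight, while edge_index stays constant for the whole episode.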