-
Notifications
You must be signed in to change notification settings - Fork 493
Expand file tree
/
Copy pathexample_torch.py
More file actions
199 lines (148 loc) · 6.35 KB
/
example_torch.py
File metadata and controls
199 lines (148 loc) · 6.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
###########################################################################
# Example Torch
#
# Optimizes the Rosenbrock function using the PyTorch Adam optimizer
# The Rosenbrock function is a non-convex function, and is often used
# to test optimization algorithms. The function is defined as:
# f(x, y) = (a - x)^2 + b * (y - x^2)^2
# where a = 1 and b = 100. The minimum value of the function is 0 at (1, 1).
#
# The example demonstrates how to set up a torch.autograd.Function to
# incorporate Warp kernel launches within a PyTorch graph.
###########################################################################
import numpy as np
import torch
import warp as wp
# Rosenbrock function with a = 1, b = 100:
#     f(x, y) = (1 - x)^2 + 100 * (y - x^2)^2
# Global minimum f = 0 at (1, 1). Declared as a Warp device function so it can
# be called from kernels (including their adjoint launches).
@wp.func
def rosenbrock(x: float, y: float):
    # first term: distance from x = 1
    valley = (1.0 - x) ** 2.0
    # second term: steep penalty for leaving the parabola y = x^2
    ridge = 100.0 * (y - x**2.0) ** 2.0
    return valley + ridge
# Kernel: each thread evaluates the Rosenbrock function at its point xs[tid]
# and writes the result to z[tid].
@wp.kernel
def eval_rosenbrock(xs: wp.array[wp.vec2], z: wp.array[float]):
    tid = wp.tid()
    point = xs[tid]
    px = point[0]
    py = point[1]
    z[tid] = rosenbrock(px, py)
class Rosenbrock(torch.autograd.Function):
    """Autograd bridge that evaluates the Rosenbrock function with Warp.

    ``forward`` launches ``eval_rosenbrock`` on the particle positions;
    ``backward`` launches the same kernel with ``adjoint=True`` to propagate
    the incoming output gradient back to the positions.
    """

    @staticmethod
    def forward(ctx, xy, num_particles):
        """Evaluate the Rosenbrock function for every particle.

        Args:
            ctx: Autograd context; the Warp arrays needed by ``backward`` are
                stored on it.
            xy: Torch tensor of particle positions, viewed by Warp as an array
                of ``wp.vec2`` (presumably shape ``(num_particles, 2)``, float32).
            num_particles: Number of kernel threads to launch (one per particle).

        Returns:
            Torch tensor backed by the Warp output array, one value per particle.
        """
        # Wrap the positions WITHOUT their gradient. With requires_grad=True (also the
        # default when ``xy.requires_grad`` is set), wp.from_torch binds -- and allocates
        # if missing -- ``xy.grad`` as the adjoint buffer. backward would then return a
        # tensor aliasing the leaf's own ``.grad``, and PyTorch would add it onto itself
        # when accumulating. A private adjoint buffer is allocated in backward instead.
        ctx.xy = wp.from_torch(xy, dtype=wp.vec2, requires_grad=False)
        ctx.num_particles = num_particles

        # launch on the device that holds the input, not on whatever Warp's default is
        device = ctx.xy.device

        # allocate output
        ctx.z = wp.zeros(num_particles, dtype=float, device=device)

        wp.launch(kernel=eval_rosenbrock, dim=num_particles, inputs=[ctx.xy], outputs=[ctx.z], device=device)

        return wp.to_torch(ctx.z)

    @staticmethod
    def backward(ctx, adj_z):
        """Propagate ``adj_z`` (dL/dz) back to the particle positions.

        Args:
            ctx: Autograd context populated by ``forward``.
            adj_z: Gradient of the loss w.r.t. the per-particle outputs.

        Returns:
            ``(dL/dxy, None)`` -- ``num_particles`` is not differentiable.
        """
        device = ctx.xy.device

        # Incoming gradients may be non-contiguous (e.g. the broadcast gradient of
        # ``z.sum()``); make them dense before handing the memory to Warp.
        adj_z_wp = wp.from_torch(adj_z.contiguous(), requires_grad=False)

        # Adjoint launches accumulate into adj_inputs, so start from a zeroed buffer.
        adj_xy = wp.zeros_like(ctx.xy)

        wp.launch(
            kernel=eval_rosenbrock,
            dim=ctx.num_particles,
            inputs=[ctx.xy],
            outputs=[ctx.z],
            adj_inputs=[adj_xy],
            adj_outputs=[adj_z_wp],
            adjoint=True,
            device=device,
        )

        # return adjoint w.r.t. inputs
        return (wp.to_torch(adj_xy), None)
class Example:
    """Optimize a cloud of 2D particles toward the Rosenbrock minimum with Adam.

    Particle positions are a single learnable torch tensor; the per-particle
    loss is evaluated by the Warp-backed ``Rosenbrock`` autograd function.
    Unless ``headless`` is set, a matplotlib figure (``self.fig``) shows the
    function as a heatmap together with the particles and their mean.

    Args:
        headless: If True, no matplotlib objects are created and ``render()``
            is a no-op.
        train_iters: Optimizer steps performed per ``step_and_render()`` call.
        num_particles: Number of particles being optimized.
    """

    def __init__(self, headless=False, train_iters=10, num_particles=1500):
        self.num_particles = num_particles
        self.train_iters = train_iters
        self.train_iter = 0
        self.learning_rate = 5e-2

        # allocate the torch tensors on the device Warp is currently using
        self.torch_device = wp.device_to_torch(wp.get_device())

        # fixed seed -> reproducible initial particle cloud
        rng = np.random.default_rng(42)
        self.xy = torch.tensor(
            rng.normal(size=(self.num_particles, 2)), dtype=torch.float32, requires_grad=True, device=self.torch_device
        )
        # host copy of the positions read by render(); it must exist before the first step()
        self.xy_np = self.xy.numpy(force=True)
        # mean position drawn by the marker; start from the actual initial mean
        # rather than uninitialized memory
        self.mean_pos = np.mean(self.xy_np, axis=0)

        self.opt = torch.optim.Adam([self.xy], lr=self.learning_rate)

        if headless:
            self.scatter_plot = None
            self.mean_marker = None
        else:
            # create_plot() also sets self.fig and self.mean_marker
            self.scatter_plot = self.create_plot()

    def create_plot(self):
        """Draw the Rosenbrock landscape on [-2, 2] x [-2, 2] and prepare particle artists.

        The background values are computed with the ``eval_rosenbrock`` Warp
        kernel on a 100 x 100 grid.

        Side effects:
            Sets ``self.fig`` (the new figure) and ``self.mean_marker`` (an
            initially empty line artist used to show the particle mean).

        Returns:
            An initially empty scatter artist for the particle positions.
        """
        import matplotlib.pyplot as plt  # noqa: PLC0415

        min_x, max_x = -2.0, 2.0
        min_y, max_y = -2.0, 2.0

        # Create a grid of points
        x = np.linspace(min_x, max_x, 100)
        y = np.linspace(min_y, max_y, 100)
        X, Y = np.meshgrid(x, y)
        xy = np.column_stack((X.flatten(), Y.flatten()))
        N = len(xy)

        # evaluate the function on the grid with the same kernel used for training
        xy = wp.array(xy, dtype=wp.vec2)
        Z = wp.empty(N, dtype=float)
        wp.launch(eval_rosenbrock, dim=N, inputs=[xy], outputs=[Z])
        Z = Z.numpy().reshape(X.shape)

        # Plot the function as a heatmap
        self.fig = plt.figure(figsize=(6, 6))
        ax = plt.gca()
        plt.imshow(Z, extent=[min_x, max_x, min_y, max_y], origin="lower", interpolation="bicubic", cmap="coolwarm")
        plt.contour(X, Y, Z, extent=[min_x, max_x, min_y, max_y], levels=150, colors="k", alpha=0.5, linewidths=0.5)

        # Plot optimum
        plt.plot(1, 1, "*", color="r", markersize=10)

        plt.title("Rosenbrock function")
        plt.xlabel("x")
        plt.ylabel("y")

        (self.mean_marker,) = ax.plot([], [], "o", color="w", markersize=5)

        # Create a scatter plot (initially empty)
        return ax.scatter([], [], c="k", s=2)

    def forward(self):
        """Evaluate the Rosenbrock function for every particle and store it in ``self.z``."""
        self.z = Rosenbrock.apply(self.xy, self.num_particles)

    def step(self):
        """Run one Adam update on all particles and refresh the host-side statistics."""
        self.opt.zero_grad()
        self.forward()
        # seeding with ones is equivalent to minimizing the sum of all particle losses
        self.z.backward(torch.ones_like(self.z))
        self.opt.step()

        # Update the host copy used by the scatter plot
        self.xy_np = self.xy.numpy(force=True)

        # Compute mean
        self.mean_pos = np.mean(self.xy_np, axis=0)
        print(f"\rIter {self.train_iter:5d} particle mean: {self.mean_pos[0]:.8f}, {self.mean_pos[1]:.8f} ", end="")

        self.train_iter += 1

    def render(self):
        """Push the latest particle positions and mean into the plot (no-op when headless)."""
        if self.scatter_plot is None:
            return

        self.scatter_plot.set_offsets(np.c_[self.xy_np[:, 0], self.xy_np[:, 1]])
        self.mean_marker.set_data([self.mean_pos[0]], [self.mean_pos[1]])

    # Function to update the scatter plot
    def step_and_render(self, frame):
        """Advance ``train_iters`` optimizer steps, then redraw.

        Args:
            frame: Frame index supplied by ``FuncAnimation``; unused.
        """
        for _ in range(self.train_iters):
            self.step()

        self.render()
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--device", type=str, default=None, help="Override the default Warp device.")
    parser.add_argument("--num-frames", type=int, default=10000, help="Total number of frames.")
    parser.add_argument("--train-iters", type=int, default=10, help="Total number of training iterations per frame.")
    parser.add_argument(
        "--num-particles", type=int, default=1500, help="Total number of particles to use in optimization."
    )
    parser.add_argument(
        "--headless",
        action="store_true",
        help="Run in headless mode, suppressing the opening of any graphical windows.",
    )

    # parse_known_args tolerates extra arguments injected by external runners
    args = parser.parse_known_args()[0]

    with wp.ScopedDevice(args.device):
        example = Example(headless=args.headless, train_iters=args.train_iters, num_particles=args.num_particles)

        if not args.headless:
            import matplotlib.pyplot as plt
            from matplotlib.animation import FuncAnimation

            # Create the animation; the reference held in ``ani`` keeps it alive
            # until plt.show() returns
            ani = FuncAnimation(
                example.fig, example.step_and_render, frames=args.num_frames, interval=100, repeat=False
            )

            # Display the animation
            plt.show()
        else:
            # Same per-frame work as the interactive path: --train-iters optimizer
            # steps per frame (render() is a no-op when headless).
            for frame in range(args.num_frames):
                example.step_and_render(frame)