NeRF: A Volume Rendering Perspective
Overview
NeRF implicitly represents a 3D scene with a multi-layer perceptron (MLP) $F_\Theta : (\mathbf{x}, \mathbf{d}) \mapsto (\mathbf{c}, \sigma)$ that maps a position $\mathbf{x}$ and a view direction $\mathbf{d}$ to a color $\mathbf{c}$ and an "opacity" $\sigma$. Rendered results are spectacular.
There have been a number of articles introducing NeRF since its publication in 2020. While most posts cover the general method, few of them elaborate on why the volume rendering procedure
$$C(\mathbf{r}) = \int_{t_n}^{t_f} T(t)\,\sigma(\mathbf{r}(t))\,\mathbf{c}(\mathbf{r}(t), \mathbf{d})\,\mathrm{d}t, \qquad T(t) = \exp\!\left(-\int_{t_n}^{t}\sigma(\mathbf{r}(s))\,\mathrm{d}s\right)$$
for a ray $\mathbf{r}(t) = \mathbf{o} + t\mathbf{d}$ works and how the equation is reduced to
$$\hat{C}(\mathbf{r}) = \sum_{i=1}^{N} T_i\big(1 - e^{-\sigma_i\delta_i}\big)\,\mathbf{c}_i, \qquad T_i = \exp\!\left(-\sum_{j=1}^{i-1}\sigma_j\delta_j\right)$$
via numerical quadrature, let alone exploring its implementation via the Monte Carlo method.
This post delves into the volume rendering aspect of NeRF: the equations are derived, and their implementation is analyzed.
Prerequisites
- Having read the NeRF paper
- Ability to solve ordinary differential equations (ODEs)
- Introductory probability theory
- Familiarity with PyTorch and NumPy
Background
The rendering formula
A ray $\mathbf{r}(t) = \mathbf{o} + t\mathbf{d}$ with origin $\mathbf{o}$ and direction $\mathbf{d}$ is cast through a bounded region of space. Assume, for simplicity, that the cross-sectional area $A$ is uniform along the ray. Let's focus on a slice of the region at depth $t$ with thickness $\Delta t$.
Occluding objects are modeled as spherical particles with radius $r$. Let $\rho(t)$ denote the particle density — the number of particles per unit volume — at that depth. If $\Delta t$ is small enough that $\rho$ is constant over the slice, then there are
$$\rho(t)\,A\,\Delta t$$
particles contained in the slice. When $\Delta t \to 0$, the solid particles do not overlap (as seen along the ray), so a total of
$$\rho(t)\,A\,\Delta t \cdot \pi r^2$$
area is occluded, which amounts to a portion $\rho(t)\,\pi r^2\,\Delta t$ of the cross section. Let $I(t)$ denote the light intensity at depth $t$ from the origin along direction $\mathbf{d}$. If a portion $\rho(t)\,\pi r^2\,\Delta t$ of all rays is occluded, then the intensity at depth $t + \Delta t$ decreases to
$$I(t + \Delta t) = \big(1 - \rho(t)\,\pi r^2\,\Delta t\big)\,I(t).$$
The difference in intensity is
$$\Delta I = I(t + \Delta t) - I(t) = -\rho(t)\,\pi r^2\,I(t)\,\Delta t.$$
Taking the step from discrete to continuous, we have
$$\frac{\mathrm{d}I(t)}{\mathrm{d}t} = -\rho(t)\,\pi r^2\,I(t).$$
Define the volume density (or voxel "opacity") $\sigma(t) = \rho(t)\,\pi r^2$. This makes sense because the amount of attenuation depends on both the number of occluding particles and their size. Then the solution to the ODE
$$\frac{\mathrm{d}I(t)}{\mathrm{d}t} = -\sigma(t)\,I(t)$$
is
$$I(t) = I(t_n)\exp\!\left(-\int_{t_n}^{t}\sigma(s)\,\mathrm{d}s\right).$$
Step-by-step solution
Exchange the terms on both sides of the ODE:
$$\frac{\mathrm{d}I(t)}{I(t)} = -\sigma(t)\,\mathrm{d}t,$$
which is a separable DE. Integrate both sides:
$$\int_{t_n}^{t}\frac{\mathrm{d}I(s)}{I(s)} = -\int_{t_n}^{t}\sigma(s)\,\mathrm{d}s \;\Longrightarrow\; \ln I(t) - \ln I(t_n) = -\int_{t_n}^{t}\sigma(s)\,\mathrm{d}s.$$
Suppose the intensity takes the value $I(t_n)$ at depth $t_n$; then
$$I(t) = I(t_n)\exp\!\left(-\int_{t_n}^{t}\sigma(s)\,\mathrm{d}s\right).$$
Define the accumulated transmittance $T(t) = \exp\!\left(-\int_{t_n}^{t}\sigma(s)\,\mathrm{d}s\right)$; then $I(t) = T(t)\,I(t_n)$ means the remaining intensity after the ray travels from $t_n$ to $t$. $T(t)$ can also be viewed as the probability that a ray does not hit any particle from $t_n$ to $t$. But no color will be observed if a ray passes through empty space; radiance is "emitted" only when there is contact between rays and particles. Define
$$F(t) = 1 - T(t)$$
as the cumulative distribution function (CDF) that a ray hits a particle somewhere between $t_n$ and $t$; then its probability density function (PDF) is
$$p(t) = \frac{\mathrm{d}F(t)}{\mathrm{d}t} = \sigma(t)\,T(t).$$
CDF to PDF
Differentiate the CDF (w.r.t. $t$) to get
$$\frac{\mathrm{d}F(t)}{\mathrm{d}t} = \frac{\mathrm{d}}{\mathrm{d}t}\big(1 - T(t)\big) = -\frac{\mathrm{d}T(t)}{\mathrm{d}t} = \sigma(t)\,T(t).$$
Let a random variable $\mathbf{C}$ denote the emitted radiance: a ray that first hits a particle at depth $t$ emits radiance $\mathbf{c}(\mathbf{r}(t), \mathbf{d})$, an event that occurs with probability density $p(t) = \sigma(t)\,T(t)$.
Hence, the color of a pixel is the expectation of the emitted radiance:
$$C(\mathbf{r}) = \mathbb{E}[\mathbf{C}] = \int_{t_n}^{t_f} T(t)\,\sigma(t)\,\mathbf{c}(\mathbf{r}(t), \mathbf{d})\,\mathrm{d}t,$$
concluding the proof.
Integration bounds
In practice, the MLP outputs the density $\sigma$ as a function of the position (the coordinate $\mathbf{r}(t)$) and the color $\mathbf{c}$ as a function of both the position and the view direction $\mathbf{d}$. Also different are the integration bounds. A computer does not support an infinite range; the lower and upper bounds of integration are a near bound $t_n$ and a far bound $t_f$ within the range of floating-point representation:
$$C(\mathbf{r}) = \int_{t_n}^{t_f} T(t)\,\sigma(\mathbf{r}(t))\,\mathbf{c}(\mathbf{r}(t), \mathbf{d})\,\mathrm{d}t.$$
In NeRF, $t_n = 0$ and $t_f = 1$ for scaled bounded scenes and for front-facing scenes after conversion to normalized device coordinates (NDC).
Numerical quadrature
We took a step from discrete to continuous to derive the rendering integral. Nevertheless, integration over a continuous domain is not supported by computers. An alternative is numerical quadrature. Sample $N$ points $t_1 \le t_2 \le \dots \le t_N$ along a ray, and define the differences between adjacent samples as
$$\delta_i = t_{i+1} - t_i;$$
then the transmittance is approximated by
$$T_i = \exp\!\left(-\sum_{j=1}^{i-1}\sigma_j\delta_j\right),$$
where $\sigma_i = \sigma(\mathbf{r}(t_i))$ and $\mathbf{c}_i = \mathbf{c}(\mathbf{r}(t_i), \mathbf{d})$. Meanwhile, the differentiation in the PDF is also substituted by a discrete difference; that is, the probability that a ray terminates in the $i$th interval is
$$p_i = T_i - T_{i+1} = T_i\big(1 - e^{-\sigma_i\delta_i}\big).$$
Step-by-step solution
Hence, the discretized emitted radiance is
$$\hat{C}(\mathbf{r}) = \sum_{i=1}^{N} T_i\big(1 - e^{-\sigma_i\delta_i}\big)\,\mathbf{c}_i,$$
where $\mathbf{c}_i$ is the RGB output of the MLP queried at $\mathbf{r}(t_i)$.
Note that if we denote $\alpha_i = 1 - e^{-\sigma_i\delta_i}$, then $\hat{C}(\mathbf{r}) = \sum_i T_i\,\alpha_i\,\mathbf{c}_i$ resembles classical alpha compositing.
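To make the quadrature concrete, here is a minimal PyTorch sketch (not the repository's code; the tensor names sigma, rgb, and t_vals are made up for illustration) that evaluates $\hat{C}(\mathbf{r})$ for a batch of rays:

```python
import torch

def composite(sigma, rgb, t_vals):
    """Discrete volume rendering: C^ = sum_i T_i * (1 - exp(-sigma_i * delta_i)) * c_i.

    sigma:  (N_rays, N_samples)     volume densities
    rgb:    (N_rays, N_samples, 3)  colors
    t_vals: (N_rays, N_samples)     sample depths along each ray
    """
    delta = t_vals[..., 1:] - t_vals[..., :-1]                   # (N_rays, N_samples-1)
    delta = torch.cat([delta, 1e10 * torch.ones_like(delta[..., :1])], -1)
    alpha = 1.0 - torch.exp(-sigma * delta)                      # per-sample opacity
    # T_i = prod_{j<i} (1 - alpha_j): an "exclusive" cumulative product
    trans = torch.cumprod(torch.cat([torch.ones_like(alpha[..., :1]),
                                     1.0 - alpha], -1), -1)[..., :-1]
    weights = trans * alpha                                      # (N_rays, N_samples)
    return (weights[..., None] * rgb).sum(dim=-2)                # (N_rays, 3)

# toy usage
sigma = torch.rand(4, 64)
rgb = torch.rand(4, 64, 3)
t_vals = torch.linspace(0., 1., 64).expand(4, 64)
print(composite(sigma, rgb, t_vals).shape)  # torch.Size([4, 3])
```

Appending a huge final delta mimics an unbounded last interval; the repository's raw2outputs(…) analyzed later follows the same pattern.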
Alpha compositing
Consider the case where a foreground object is inserted in front of the background. Since a pixel displays a single color, we have to blend the colors of the two. Compositing applies when there are partially transparent regions within the foreground object, or when the foreground object only partially covers the background.
In alpha compositing, a parameter $\alpha$ determines the extent to which each object contributes to what is displayed in a pixel. Let $\alpha$ denote the opacity (or pixel coverage) of the foreground object; then a pixel showing foreground color $\mathbf{c}_f$ over background color $\mathbf{c}_b$ is composited as
$$\mathbf{C} = \alpha\,\mathbf{c}_f + (1 - \alpha)\,\mathbf{c}_b.$$
When blending the colors of multiple objects, one can adopt a divide-and-conquer approach: each time, handle the unregistered object closest to the eye and treat the remaining objects as a single entity. Such a strategy is formulated by
$$\mathbf{C}_i = \alpha_i\,\mathbf{c}_i + (1 - \alpha_i)\,\mathbf{C}_{i+1},$$
which is essentially a tail recursion.
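A minimal sketch of this recursion, unrolled back to front (the names colors and alphas are illustrative, not from the repository):

```python
import torch

def over_composite(colors, alphas):
    """Recursive alpha compositing: C_i = alpha_i * c_i + (1 - alpha_i) * C_{i+1}.

    colors: (N, 3) colors sorted from nearest to farthest
    alphas: (N,)   opacities in the same order
    """
    out = torch.zeros(3)
    # evaluate the tail recursion from the farthest object to the nearest
    for c, a in zip(reversed(colors), reversed(alphas)):
        out = a * c + (1.0 - a) * out
    return out

colors = torch.tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
alphas = torch.tensor([0.3, 0.5, 1.0])  # an opaque background comes last
print(over_composite(colors, alphas))
```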
Alpha compositing in NeRF
There are $N$ samples along each ray in NeRF. Consider the first $N - 1$ samples as occluding foreground objects with opacities $\alpha_i$ and colors $\mathbf{c}_i$, and the last sample as the background; then the blended pixel value is
$$\mathbf{C} = \sum_{i=1}^{N}\left(\prod_{j=1}^{i-1}(1 - \alpha_j)\right)\alpha_i\,\mathbf{c}_i,$$
where $\alpha_i = 1 - e^{-\sigma_i\delta_i}$. Recall $T_i = \exp\!\left(-\sum_{j=1}^{i-1}\sigma_j\delta_j\right)$; then it remains to show that $T_i = \prod_{j=1}^{i-1}(1 - \alpha_j)$:
$$\prod_{j=1}^{i-1}(1 - \alpha_j) = \prod_{j=1}^{i-1} e^{-\sigma_j\delta_j} = \exp\!\left(-\sum_{j=1}^{i-1}\sigma_j\delta_j\right) = T_i,$$
concluding the proof. This manifests the elegance of differentiable volume rendering.
Rewriting $w_i = T_i\,\alpha_i$, the expectation of emitted radiance is a weighted sum of colors:
$$\hat{C}(\mathbf{r}) = \sum_{i=1}^{N} w_i\,\mathbf{c}_i.$$
Why (trivially) differentiable?
Given the above renderer, a coarse training pipeline is: generate rays from the training images and poses, query the MLP at samples along each ray, render colors via the quadrature, and compare them against the ground-truth pixels.
If the discrete renderer is differentiable, then we can train the model end to end through gradient descent. No surprise: given a (sorted) sequence of random samples $\{t_i\}$, $\hat{C}(\mathbf{r})$ is built from sums, products, and exponentials of the MLP outputs, so its derivatives exist in closed form; for instance,
$$\frac{\partial\hat{C}(\mathbf{r})}{\partial\mathbf{c}_i} = T_i\big(1 - e^{-\sigma_i\delta_i}\big) = w_i.$$
Once the renderer is differentiable, weights and biases in an MLP can be updated via the chain rule.
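A quick check that gradients indeed flow through the quadrature, reusing the hypothetical composite(…) sketch from earlier (this is not the repository's training loop):

```python
import torch

# `composite` is the sketch defined in the quadrature section above
sigma = torch.rand(2, 32, requires_grad=True)
rgb = torch.rand(2, 32, 3, requires_grad=True)
t_vals = torch.linspace(0., 1., 32).expand(2, 32)

target = torch.rand(2, 3)
loss = ((composite(sigma, rgb, t_vals) - target) ** 2).mean()
loss.backward()

print(sigma.grad.shape, rgb.grad.shape)  # gradients w.r.t. every MLP output
```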
Coarse-to-fine approach
NeRF jointly optimizes a coarse and a fine network: the weights output by the coarse network guide the hierarchical sampling of additional points, which are then fed to the fine network.
Analysis
Whereas NeRF was originally implemented in TensorFlow, this code analysis is based on a faithful reproduction in PyTorch.
Let's experiment with the LLFF dataset, which is comprised of front-facing scenes with camera poses. Pertinent directories and files of the repository are
Item | Type | Description |
---|---|---|
configs | directory | contains per scene configuration (.txt ) for the LLFF dataset |
download_example_data.sh | shell script | to download datasets |
load_llff.py | Python script | data loader of the LLFF dataset |
run_nerf.py | Python script | main procedures |
run_nerf_helpers.py | Python script | utility functions |
Modified indentation and comments
Code in this post deviates slightly from the authentic version. Dataflow and function calls remain intact, whereas indentation and comments are modified for the sake of readability.
The big picture
if __name__ == '__main__':
torch.set_default_tensor_type('torch.cuda.FloatTensor')
train()
As shown, train(…)
in run_nerf.py
is the entry point of the project. The entire training process is
def train():
parser = config_parser()
args = parser.parse_args()
# load data
K = None
if args.dataset_type == 'llff':
images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
recenter=True, bd_factor=.75,
spherify=args.spherify)
hwf = poses[0,:3,-1]
poses = poses[:,:3,:4]
print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
if not isinstance(i_test, list):
i_test = [i_test]
if args.llffhold > 0:
print('Auto LLFF holdout,', args.llffhold)
i_test = np.arange(images.shape[0])[ : :args.llffhold]
i_val = i_test
i_train = np.array([i for i in np.arange(int(images.shape[0]))
if (i not in i_test and i not in i_val)])
print('DEFINING BOUNDS')
if args.no_ndc:
near = np.ndarray.min(bds) * .9
far = np.ndarray.max(bds) * 1.
else:
near = 0.
far = 1.
print('NEAR FAR', near, far)
elif args.dataset_type == 'blender':
images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, args.testskip)
print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
i_train, i_val, i_test = i_split
near = 2.
far = 6.
if args.white_bkgd:
images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
else:
images = images[...,:3]
elif args.dataset_type == 'LINEMOD':
images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(args.datadir, args.half_res, args.testskip)
print(f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}')
print(f'[CHECK HERE] near: {near}, far: {far}.')
i_train, i_val, i_test = i_split
if args.white_bkgd:
images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
else:
images = images[...,:3]
elif args.dataset_type == 'deepvoxels':
images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape,
basedir=args.datadir,
testskip=args.testskip)
print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir)
i_train, i_val, i_test = i_split
hemi_R = np.mean(np.linalg.norm(poses[:,:3,-1], axis=-1))
near = hemi_R-1.
far = hemi_R+1.
else:
print('Unknown dataset type', args.dataset_type, 'exiting')
return
# cast intrinsics to right types
H, W, focal = hwf
H, W = int(H), int(W)
hwf = [H, W, focal]
if K is None:
K = np.array([
[focal, 0 , 0.5*W],
[0 , focal, 0.5*H],
[0 , 0 , 1 ]
])
if args.render_test:
render_poses = np.array(poses[i_test])
# create log dir and copy the config file
basedir = args.basedir
expname = args.expname
os.makedirs(os.path.join(basedir, expname), exist_ok=True)
f = os.path.join(basedir, expname, 'args.txt')
with open(f, 'w') as file:
for arg in sorted(vars(args)):
attr = getattr(args, arg)
file.write('{} = {}\n'.format(arg, attr))
if args.config is not None:
f = os.path.join(basedir, expname, 'config.txt')
with open(f, 'w') as file:
file.write(open(args.config, 'r').read())
# create nerf model
render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
global_step = start
bds_dict = {'near': near,
'far' : far}
render_kwargs_train.update(bds_dict)
render_kwargs_test.update(bds_dict)
# move test data to GPU
render_poses = torch.Tensor(render_poses).to(device)
# short circuit if rendering from trained model
if args.render_only:
print('RENDER ONLY')
with torch.no_grad():
if args.render_test:
images = images[i_test] # switch to test poses
else:
# default is smoother render_poses path
images = None
testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start))
os.makedirs(testsavedir, exist_ok=True)
print('test poses shape', render_poses.shape)
rgbs, _ = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
print('Done rendering', testsavedir)
imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8)
return
# prepare raybatch tensor if batching random rays
N_rand = args.N_rand
use_batching = not args.no_batching
if use_batching:
# random ray batching
print('get rays')
rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # (num_img, ro+rd, H, W, 3)
print('done, concats')
rays_rgb = np.concatenate([rays, images[:,None]], 1) # (num_img, ro+rd+rgb, H, W, 3)
rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # (num_img, H, W, ro+rd+rgb, 3)
rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # training set only
rays_rgb = np.reshape(rays_rgb, [-1,3,3]) # ((num_img-1)*H*W, ro+rd+rgb, 3)
rays_rgb = rays_rgb.astype(np.float32)
print('shuffle rays')
np.random.shuffle(rays_rgb)
print('done')
i_batch = 0
# move training data to GPU
if use_batching:
images = torch.Tensor(images).to(device)
poses = torch.Tensor(poses).to(device)
if use_batching:
rays_rgb = torch.Tensor(rays_rgb).to(device)
N_iters = 200000 + 1
print('Begin')
print('TRAIN views are', i_train)
print('TEST views are', i_test)
print('VAL views are', i_val)
# summary writers
#writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))
start = start + 1
for i in trange(start, N_iters):
time0 = time.time()
# sample random ray batch
if use_batching:
# random over all images
batch = rays_rgb[i_batch:i_batch+N_rand] # (B, 2+1, 3*?)
batch = torch.transpose(batch, 0, 1)
batch_rays, target_s = batch[:2], batch[2]
i_batch += N_rand
if i_batch >= rays_rgb.shape[0]:
print("Shuffle data after an epoch!")
rand_idx = torch.randperm(rays_rgb.shape[0])
rays_rgb = rays_rgb[rand_idx]
i_batch = 0
else:
# random from one image
img_i = np.random.choice(i_train)
target = images[img_i]
target = torch.Tensor(target).to(device)
pose = poses[img_i, :3,:4]
if N_rand is not None:
rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose)) # (H, W, 3), (H, W, 3)
if i < args.precrop_iters:
dH = int(H//2 * args.precrop_frac)
dW = int(W//2 * args.precrop_frac)
coords = torch.stack(
torch.meshgrid(
torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH),
torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW)
), -1)
if i == start:
print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}")
else:
coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2)
coords = torch.reshape(coords, [-1,2]) # (H * W, 2)
select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
select_coords = coords[select_inds].long() # (N_rand, 2)
rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
batch_rays = torch.stack([rays_o, rays_d], 0)
target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
##### core optimization loop #####
rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
verbose=i < 10, retraw=True,
**render_kwargs_train)
optimizer.zero_grad()
img_loss = img2mse(rgb, target_s)
trans = extras['raw'][...,-1]
loss = img_loss
psnr = mse2psnr(img_loss)
if 'rgb0' in extras:
img_loss0 = img2mse(extras['rgb0'], target_s)
loss = loss + img_loss0
psnr0 = mse2psnr(img_loss0)
loss.backward()
optimizer.step()
# NOTE: IMPORTANT!
### update learning rate ###
decay_rate = 0.1
decay_steps = args.lrate_decay * 1000
new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
for param_group in optimizer.param_groups:
param_group['lr'] = new_lrate
################################
dt = time.time()-time0
# print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")
##### end #####
# rest is logging
if i%args.i_weights==0:
path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
torch.save({
'global_step': global_step,
'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, path)
print('Saved checkpoints at', path)
if i%args.i_video==0 and i > 0:
# test mode
with torch.no_grad():
rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test)
print('Done, saving', rgbs.shape, disps.shape)
moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)
#if args.use_viewdirs:
# render_kwargs_test['c2w_staticcam'] = render_poses[0][:3,:4]
# with torch.no_grad():
# rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
# render_kwargs_test['c2w_staticcam'] = None
# imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8)
if i%args.i_testset==0 and i > 0:
testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
os.makedirs(testsavedir, exist_ok=True)
print('test poses shape', poses[i_test].shape)
with torch.no_grad():
render_path(torch.Tensor(poses[i_test]).to(device), hwf, K, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir)
print('Saved test set')
if i%args.i_print==0:
tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}")
"""
print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
print('iter time {:.05f}'.format(dt))
with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print):
tf.contrib.summary.scalar('loss', loss)
tf.contrib.summary.scalar('psnr', psnr)
tf.contrib.summary.histogram('tran', trans)
if args.N_importance > 0:
tf.contrib.summary.scalar('psnr0', psnr0)
if i%args.i_img==0:
# log a rendered validation view to Tensorboard
img_i=np.random.choice(i_val)
target = images[img_i]
pose = poses[img_i, :3,:4]
with torch.no_grad():
rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, c2w=pose,
**render_kwargs_test)
psnr = mse2psnr(img2mse(rgb, target))
with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
tf.contrib.summary.image('rgb', to8b(rgb)[tf.newaxis])
tf.contrib.summary.image('disp', disp[tf.newaxis,...,tf.newaxis])
tf.contrib.summary.image('acc', acc[tf.newaxis,...,tf.newaxis])
tf.contrib.summary.scalar('psnr_holdout', psnr)
tf.contrib.summary.image('rgb_holdout', target[tf.newaxis])
if args.N_importance > 0:
with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
tf.contrib.summary.image('rgb0', to8b(extras['rgb0'])[tf.newaxis])
tf.contrib.summary.image('disp0', extras['disp0'][tf.newaxis,...,tf.newaxis])
tf.contrib.summary.image('z_std', extras['z_std'][tf.newaxis,...,tf.newaxis])
"""
global_step += 1
which is lengthy and potentially obscure. Function calls are visualized below, which assists comprehension of the rendering pipeline.
As captioned, let's concentrate on the implementation of volume rendering in this post. Our journey starts from ray generation (get_rays_np(…)) and culminates in the learning-rate update in each iteration.
Prior to rendering are the data loader and network initialization. We omit the analysis of the data loader; though the loaded images and poses will be introduced upon their first appearance, feel free to peruse this post for details. Neither will we delve into the render-only branch, which terminates the program immediately after rendering a novel view or video — testing NeRF is not our primary concern. Network creation and initialization will nevertheless be covered in the appendix.
Function analyses are organized in horizontal tabs.
As the function flow chart shows, the call graph in NeRF is complex. For clarity, a section may contain a few horizontal tabs, each responsible for a single function.
Training set
# prepare raybatch tensor if batching random rays
N_rand = args.N_rand
use_batching = not args.no_batching
if use_batching:
# random ray batching
print('get rays')
rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # (num_img, ro+rd, H, W, 3)
print('done, concats')
rays_rgb = np.concatenate([rays, images[:,None]], 1) # (num_img, ro+rd+rgb, H, W, 3)
rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # (num_img, H, W, ro+rd+rgb, 3)
rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # training set only
rays_rgb = np.reshape(rays_rgb, [-1,3,3]) # ((num_img-1)*H*W, ro+rd+rgb, 3)
rays_rgb = rays_rgb.astype(np.float32)
print('shuffle rays')
np.random.shuffle(rays_rgb)
print('done')
i_batch = 0
There are command line argument variables (CL args) in the above snippet:
Variable | Value | Description |
---|---|---|
N_rand | by default | batch size: number of random rays per optimization loop |
no_batching | False by default | whether to draw rays from a single image per iteration instead of batching rays across all images |
use_batching, therefore, is asserted by default. The conditioned block contains a few variables defined elsewhere, most of which are produced by the data loader earlier in train(…):
Variable | Type | Dimension | Description |
---|---|---|---|
H | int | | height of the image plane in pixels |
W | int | | width of the image plane in pixels |
K | NumPy array | 3 × 3 | calibration matrix (the camera intrinsics) containing the focal length of the camera; constructed in train(…) |
poses | NumPy array | | all camera poses of the scene, one per image |
images | NumPy array | | all images of the scene |
i_train | NumPy array | | indices of training images, disjoint from i_test. i_test is initially provided by the dataloader and is then overridden in train(…) since args.llffhold is nonzero by default |
Camera intrinsics
The calibration matrix takes the general form
$$K = \begin{bmatrix} f & s & c_x \\ 0 & af & c_y \\ 0 & 0 & 1 \end{bmatrix}$$
for some focal length $f$, aspect ratio $a$, skew $s$, and principal point $(c_x, c_y)$.
$a = 1$ unless pixels are not square. "$s$ encodes possible skew between the sensor axes due to the sensor not being mounted perpendicular to the optical axis." $(c_x, c_y)$ denotes the image center in pixel coordinates. In practice, $K$ is simplified to
$$K = \begin{bmatrix} f & 0 & W/2 \\ 0 & f & H/2 \\ 0 & 0 & 1 \end{bmatrix}.$$
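As a quick illustration (a sketch, not the repository's code; the numeric values are made up), the simplified $K$ lets us back-project a pixel $(i, j)$ to a direction in the camera frame — exactly the arithmetic that get_rays_np(…) vectorizes below:

```python
import numpy as np

H, W, focal = 378, 504, 407.6  # illustrative LLFF-like values
K = np.array([[focal, 0.,    0.5 * W],
              [0.,    focal, 0.5 * H],
              [0.,    0.,    1.     ]])

# back-project pixel (i, j) to a direction in the camera frame
i, j = 100.0, 200.0
dir_cam = np.array([ (i - K[0][2]) / K[0][0],   # x: right
                    -(j - K[1][2]) / K[1][1],   # y: up (image j grows downward)
                    -1.0])                      # z: the camera looks down -z
print(dir_cam)
```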
get_rays_np(…) is then invoked to generate rays (see right tab). Iterating over all images, rays has shape (num_img, 2, H, W, 3), stacking the ray origins and directions of every view. The subsequent lines pack rays_o, rays_d, and images together and transpose the result to shape (num_img, H, W, 3, 3). The remaining lines keep only the training views, reshape, and shuffle the rays; the final result rays_rgb is of dimension (N_train · H · W, 3, 3).
Misleading comment
The training-set dimension is commented as ((num_img-1)*H*W, ro+rd+rgb, 3), which implies that only one image per scene is held out for testing. This is not true for the LLFF dataset: the behavior of the dataloader is overridden in train(…), where every args.llffhold-th image is held out.
get_rays_np(…)
is called by the line rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0)
, where H
and W
are respectively the height and width of the image plane, and K
is the camera intrinsics. p
is more physically involved, detailed below.
Suppose the world frame (canonical coordinates) is characterized by an orthonormal basis $(\mathbf{x}, \mathbf{y}, \mathbf{z})$ and an origin $\mathbf{o}$, and that camera space is defined by an orthonormal basis $(\mathbf{u}, \mathbf{v}, \mathbf{w})$ and an origin (eye) $\mathbf{e}$. Express the camera-space basis and origin w.r.t. canonical coordinates as column vectors $\mathbf{u}, \mathbf{v}, \mathbf{w}, \mathbf{e}$; then p is the $3 \times 4$ frame-to-canonical matrix
$$\texttt{p} = \begin{bmatrix} \mathbf{u} & \mathbf{v} & \mathbf{w} & \mathbf{e} \end{bmatrix}$$
that maps rays in camera space to world coordinates.
def get_rays_np(H, W, K, c2w):
i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
dirs = np.stack([ (i-K[0][2]) / K[0][0],
-(j-K[1][2]) / K[1][1],
-np.ones_like(i) ], -1)
# rotate ray directions from camera frame to the world frame
rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
# translate camera frame's origin to the world frame. It is the origin of all rays.
rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d))
return rays_o, rays_d
np.meshgrid(…) creates a 2D grid of points, where the arrays i and j are respectively the column and row indices (the $x$- and $y$-coordinates) of the grid points. However, the image plane is bounded by
$$0 \le i \le W - 1, \qquad 0 \le j \le H - 1,$$
with its origin at a corner. Applying an offset, pixel coordinates become $\big(i - c_x,\; -(j - c_y)\big)$ in the camera frame. A ray is defined by an origin, which lies at the camera center $\mathbf{e}$, and a direction that connects the origin to a pixel. For every pixel on the image plane, the ejected direction is
$$\left(\frac{i - c_x}{f},\; -\frac{j - c_y}{f},\; -1\right),$$
i.e., "normalized" so that its $z$-component is $-1$. This corresponds to the dirs computation, except that dirs for an entire image are generated concurrently. Note that the image $j$-axis points downward, opposite to the camera $y$-axis. Consequently, the code adopts $-(j - c_y)$ (instead of $j - c_y$) as the $y$-coordinate of dirs.
Rays are then mapped to world space. Applying the rotation part of c2w to the directions yields rays_d. Ray origins rays_o are simply $\mathbf{e}$, the last column of c2w, broadcast to match the dimension of rays_d.
Coordinate transformation
A point $\mathbf{p}$ with camera coordinates $(u, v, w)$ is characterized by
$$\mathbf{p} = \mathbf{e} + u\,\mathbf{u} + v\,\mathbf{v} + w\,\mathbf{w};$$
alternatively, the same point in world space is
$$\mathbf{p} = \mathbf{o} + x_p\,\mathbf{x} + y_p\,\mathbf{y} + z_p\,\mathbf{z}.$$
The two coordinate triples are then bridged by
$$\begin{bmatrix} x_p \\ y_p \\ z_p \\ 1 \end{bmatrix} = \begin{bmatrix} \mathbf{u} & \mathbf{v} & \mathbf{w} & \mathbf{e} \\ 0 & 0 & 0 & 1 \end{bmatrix}\begin{bmatrix} u \\ v \\ w \\ 1 \end{bmatrix}.$$
Check 3Blue1Brown's videos (linear transformations, matrix multiplication, and 3D transformations) on linear algebra to understand what the above linear transformations (matrices) physically mean.
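A tiny NumPy sketch of this frame-to-canonical mapping (the basis vectors and origin below are made-up values chosen to be orthonormal, purely for illustration):

```python
import numpy as np

# camera basis vectors (as columns) and camera origin, expressed in world coordinates
u = np.array([1., 0., 0.])
v = np.array([0., 0., 1.])
w = np.array([0., -1., 0.])
e = np.array([0., 2., 1.])

c2w = np.column_stack([u, v, w, e])   # 3 x 4 frame-to-canonical matrix

d_cam = np.array([0.1, -0.2, -1.0])   # a ray direction in the camera frame
d_world = c2w[:3, :3] @ d_cam         # rotate the direction into world space
o_world = c2w[:3, -1]                 # the ray origin is the camera center

print(d_world, o_world)
```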
The returned values are
Variable | Type | Dimension | Description |
---|---|---|---|
rays_o | NumPy array | H × W × 3 | origins of the rays through the image plane |
rays_d | NumPy array | H × W × 3 | directions of the rays through the image plane |
Training preparation
def train_prepare(self, data):
…
# move training data to GPU
if use_batching:
images = torch.Tensor(images).to(device)
poses = torch.Tensor(poses).to(device)
if use_batching:
rays_rgb = torch.Tensor(rays_rgb).to(device)
N_iters = 200000 + 1
print('Begin')
print('TRAIN views are', i_train)
print('TEST views are', i_test)
print('VAL views are', i_val)
# summary writers
#writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))
start = start + 1
for i in trange(start, N_iters):
time0 = time.time()
# sample random ray batch
if use_batching:
# random over all images
batch = rays_rgb[i_batch:i_batch+N_rand] # [B, 2+1, 3*?]
batch = torch.transpose(batch, 0, 1)
batch_rays, target_s = batch[:2], batch[2]
i_batch += N_rand
if i_batch >= rays_rgb.shape[0]:
print("Shuffle data after an epoch!")
rand_idx = torch.randperm(rays_rgb.shape[0])
rays_rgb = rays_rgb[rand_idx]
i_batch = 0
else:
…
The first lines convert the training set from NumPy arrays to PyTorch tensors and "transfer" them to GPU memory, and start = start + 1 marks the commencement of the training iterations. The training data are divided into batches: let $B$ denote the batch size (N_rand); then the inputs batch_rays and the ground truth target_s have shapes $(2, B, 3)$ and $(B, 3)$ respectively. The subsequent lines handle the out-of-bound case where the index i_batch exceeds the total number of rays, in which case the rays are reshuffled and the index is reset. We do not care about the else block since use_batching is asserted by default.
Rendering
…
for i in trange(start, N_iters):
…
##### core optimization loop #####
rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
verbose=i < 10, retraw=True,
**render_kwargs_train)
…
Ensuing is volume rendering. The CL arg chunk defines the number of rays processed concurrently, which impacts performance rather than correctness. render_kwargs_train is a dictionary returned upon creating the NeRF networks (create_nerf(…) in train(…)), with the near and far bounds injected afterwards. Its internals are
Key | Element | Description |
---|---|---|
network_query_fn | a function | a subroutine that takes data and a network as input to perform query |
perturb | 1. | whether to adopt stratified sampling, 1. for True |
N_importance | number of addition samples per ray in hierarchical sampling | |
network_fine | an object | the fine network |
N_samples | number of samples per ray to coarse network | |
network_fn | an object | the coarse network |
use_viewdirs | True | whether to feed viewing directions to the network, indispensable for view-dependent appearance |
white_bkgd | False | whether to assume a white background for rendering. This applies to the synthetic dataset only, which contains images (.png) with a transparent background. |
raw_noise_std | 1. | magnitude of noise to inject into volume density |
near | 0. | lower bound of rendering integration |
far | 1. | upper bound of rendering integration |
Note
See appendix for how a NeRF model is implemented.
def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
near=0., far=1.,
use_viewdirs=False, c2w_staticcam=None,
**kwargs):
if c2w is not None:
# special case to render full image
rays_o, rays_d = get_rays(H, W, K, c2w)
else:
# use provided ray batch
rays_o, rays_d = rays
# provide ray directions as input
if use_viewdirs:
viewdirs = rays_d
if c2w_staticcam is not None:
# special case to visualize effect of viewdirs
rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
viewdirs = torch.reshape(viewdirs, [-1,3]).float()
sh = rays_d.shape # shape: … × 3
# for forward facing scenes
if ndc:
rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)
# create ray batch
rays_o = torch.reshape(rays_o, [-1,3]).float()
rays_d = torch.reshape(rays_d, [-1,3]).float()
near, far = near * torch.ones_like(rays_d[...,:1]), \
far * torch.ones_like(rays_d[...,:1])
rays = torch.cat([rays_o, rays_d, near, far], -1)
if use_viewdirs:
rays = torch.cat([rays, viewdirs], -1)
# render and reshape
all_ret = batchify_rays(rays, chunk, **kwargs)
for k in all_ret:
k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
all_ret[k] = torch.reshape(all_ret[k], k_sh)
k_extract = ['rgb_map', 'disp_map', 'acc_map']
ret_list = [all_ret[k] for k in k_extract]
ret_dict = {k : all_ret[k] for k in all_ret
if k not in k_extract}
return ret_list + [ret_dict]
The branches for rendering a full image and for the static-camera visualization are ignored because their conditions contradict the default setting. Rays are unpacked into origins and directions. Viewing directions are aliases of the ray directions, except that they are normalized to unit length. Rays are then projected to NDC space; see my other post for details. near and far are initialized to match the shape of rays_d. All of the above are concatenated such that the input tensor rays has dimension (B, 11): 3 for the origin, 3 for the direction, 1 each for near and far, and 3 for the view direction.
batchify_rays(…) (see right tab) decomposes the input tensor into mini-batches that are fed to the NeRF network sequentially. There are eight entries in the all_ret dictionary:
Key | Description of element |
---|---|
rgb_map | output color map of the fine network |
disp_map | output disparity map of the fine network |
acc_map | accumulated weights (opacity) of the fine network along each ray |
raw | raw output of the fine network |
rgb0 | output color map of the coarse network |
disp0 | output disparity map of the coarse network |
acc0 | accumulated weights (opacity) of the coarse network along each ray |
z_std | standard deviation of the fine-sample depths along each ray |
The ensuing lines group and reshape the output such that what is returned can be properly unpacked by train(…).
def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
""" render rays in smaller minibatches to avoid OOM
"""
all_ret = {}
for i in range(0, rays_flat.shape[0], chunk):
ret = render_rays(rays_flat[i:i+chunk], **kwargs)
for k in ret:
if k not in all_ret:
all_ret[k] = []
all_ret[k].append(ret[k])
all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret}
return all_ret
Mini-batches are sequentially passed to render_rays(…) (see below), cached per key, and eventually concatenated. We may consider batchify_rays(…) a broker connecting the high-level interface (render(…)) to the actual rendering implementation.
Keyword arguments are passed from the high-level interface to the "worker" procedures.
Through encapsulation with batchify_rays(…) and render(…), the training options render_kwargs_train defined previously are passed to the low-level "worker" render_rays(…), which is the core of volume rendering.
def render_rays(ray_batch,
network_fn,
network_query_fn,
N_samples,
retraw=False,
lindisp=False,
perturb=0., # 1.0, overridden by input
N_importance=0,
network_fine=None,
white_bkgd=False,
raw_noise_std=0.,
verbose=False,
pytest=False):
N_rays = ray_batch.shape[0]
rays_o, rays_d = ray_batch[:,0:3], \
ray_batch[:,3:6] # (ray #, 3)
viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 \
else None
bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2])
near, far = bounds[...,0], \
bounds[...,1] # (ray #, 1)
t_vals = torch.linspace(0., 1., steps=N_samples)
if not lindisp:
z_vals = near * (1. - t_vals) + far * t_vals
else:
z_vals = 1. / (1./near * (1. - t_vals) +
1./far * ( t_vals) )
# copy sample distances of 1 ray to the others
z_vals = z_vals.expand([N_rays, N_samples])
if perturb > 0.:
# get intervals between samples
mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
upper = torch.cat([mids, z_vals[...,-1:]], -1)
lower = torch.cat([z_vals[...,:1], mids], -1)
# stratified samples in those intervals
t_rand = torch.rand(z_vals.shape)
# pytest: overwrite U with fixed NumPy random numbers
if pytest:
np.random.seed(0)
t_rand = np.random.rand(*list(z_vals.shape))
t_rand = torch.Tensor(t_rand)
z_vals = lower + (upper - lower) * t_rand
pts = rays_o[..., None, :] + \
rays_d[..., None, :] * z_vals[..., :, None] # (ray #, sample #, 3)
#raw = run_network(pts)
raw = network_query_fn(pts, viewdirs, network_fn)
rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
# hierarchical sampling
if N_importance > 0:
# log outputs of coarse network
rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map
z_vals_mid = .5 * (z_vals[..., 1: ] + z_vals[..., :-1])
z_samples = sample_pdf(z_vals_mid,
weights[..., 1:-1],
N_importance,
det=(perturb==0.), # FALSE by default
pytest=pytest)
z_samples = z_samples.detach()
z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
pts = rays_o[..., None, :] + \
rays_d[..., None, :] * z_vals[..., :, None] # (ray #, coarse & fine sample #, 3)
run_fn = network_fn if network_fine is None \
else network_fine
#raw = run_network(pts, fn=run_fn)
raw = network_query_fn(pts, viewdirs, run_fn)
rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
ret = {'rgb_map' : rgb_map,
'disp_map': disp_map,
'acc_map' : acc_map}
if retraw:
ret['raw'] = raw
if N_importance > 0:
ret['rgb0' ] = rgb_map_0
ret['disp0'] = disp_map_0
ret['acc0' ] = acc_map_0
ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # (ray #)
for k in ret:
if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
print(f"! [Numerical Error] {k} contains nan or inf.")
return ret
The first lines unpack each mini-batch into separate physical quantities. Ray origins rays_o, ray directions rays_d, and viewing directions viewdirs have shape (ray #, 3). The integration boundaries near and far are both (ray #, 1) vectors.
The next lines initialize the samples for ray marching. torch.linspace(…) creates a sequence of N_samples points evenly scattered over unit length. Recall that the rays were previously projected to NDC space. lindisp, False by default, dictates whether samples are spaced linearly in disparity (inverse depth) instead of in depth. With near = 0 and far = 1, z_vals simply replicates t_vals and is expanded to every ray in the batch.
There are N_samples − 1 intervals between adjacent samples along each ray, yet N_samples random points are to be drawn, one per "bin". The authors cut the first and the last interval in half so that the resulting bins still tile the sampling range: the code computes the midpoints of z_vals, which are afterwards combined with the first and the last sample of each ray to form the lower and upper bin bounds.
Subtracting the lower bounds from the upper bounds finalizes the lengths of the "bins", and stratified sampling is achieved by uniformly sampling within every bin, so sample $t_i$ lies in bin $i$.
Let $t_i^{(k)}$ denote the $i$th sample on the $k$th ray in a batch; then z_vals at this point is
$$\texttt{z\_vals} = \begin{bmatrix} t_1^{(1)} & t_2^{(1)} & \cdots & t_{N_\text{samples}}^{(1)} \\ \vdots & \vdots & & \vdots \\ t_1^{(B)} & t_2^{(B)} & \cdots & t_{N_\text{samples}}^{(B)} \end{bmatrix}.$$
Sampling: slight deviation of practice from theory
The implementation of stratified sampling is inconsistent with what is described in the paper. Theoretically, a sample is obtained via
$$t_i \sim \mathcal{U}\!\left[t_n + \frac{i-1}{N}(t_f - t_n),\; t_n + \frac{i}{N}(t_f - t_n)\right],$$
which implies that every "bin" is of equal length. In practice, the first and last bins are half the size of the others. This does not harm the correctness of the algorithm.
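For comparison, a minimal sketch of the paper's equal-bin stratified sampling (a standalone illustration, not the repository's implementation):

```python
import torch

def stratified_samples(near, far, n_samples, n_rays):
    """Draw one uniform sample from each of n_samples equal-length bins per ray."""
    edges = torch.linspace(0., 1., n_samples + 1)   # bin edges in [0, 1]
    lower, upper = edges[:-1], edges[1:]            # (n_samples,)
    u = torch.rand(n_rays, n_samples)               # one draw per bin
    t = lower + (upper - lower) * u                 # stratified in [0, 1]
    return near + (far - near) * t                  # map to [near, far]

print(stratified_samples(2.0, 6.0, 8, 3))
```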
Marching the depth values z_vals along the directions rays_d, the inputs pts to network_fn are now
$$\texttt{pts}[k, i] = \mathbf{o}^{(k)} + t_i^{(k)}\,\mathbf{d}^{(k)},$$
a tensor of shape (ray #, sample #, 3).
The coarse network network_fn is then queried to predict the raw output raw (see the appendix for how NeRF is queried); raw has shape (ray #, sample #, 4), holding RGB and density per sample. "Shading" via raw2outputs(…) follows to acquire the sample weights and the radiance of each ray (see middle tab).
To distinguish the outputs of the fine network from those of the coarse one, the coarse outputs are kept under the suffix _0. Provided z_vals_mid, the midpoints of the coarse samples, and their weights, sample_pdf(…) determines the fine samples (see right tab). They are combined with the coarse samples to form a sorted tensor of depths z_vals.
Static compute graph
NeRF spans a static compute graph. The key is the ….detach() call on z_samples.
The coarse-to-fine passes are connected by hierarchical sampling, i.e., the output of the coarse MLP is used to determine the input of the fine network. Below is an illustration of how the coarse samples are processed to form the fine ones:
Hierarchically sampled inputs are again fed to the network.
Consequently, the output of the MLP becomes part of its input. It is a must to cut the coarse-to-fine edge in the compute graph because the graph has to be a directed acyclic one. Otherwise, if the fine network shared an identical instance of class NeRF with the coarse one, there would be a cyclic definition, and backpropagation would fail. This is exactly what z_samples.detach() does.
A corollary is that NeRF's compute graph is static.
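A toy demonstration of what detach() changes (illustrative only; w stands in for a coarse-network output that determines where the fine network is evaluated):

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)   # stands in for coarse outputs

# without detach: gradients flow back through the derived samples
z = w * 3.0
loss = (z ** 2).sum()
loss.backward()
print(w.grad)            # tensor([18., 36.])

# with detach: the coarse-to-fine edge is cut
w.grad = None
z = (w * 3.0).detach()   # same values, but no gradient path back to w
loss = (z ** 2).sum()
# calling loss.backward() here would raise: loss no longer requires grad
print(z.requires_grad)   # False
```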
The new inputs pts to the fine network network_fine are computed in the same way as before, now over the union of coarse and fine samples, giving a tensor of shape (ray #, coarse & fine sample #, 3). Another mass network query is performed, and the raw outputs are converted to the radiance rgb_map by a second call to raw2outputs(…).
def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
raw2alpha = lambda raw, dists, act_fn=F.relu : \
1. - torch.exp(-act_fn(raw) * dists) # σ column of `raw`
dists = z_vals[..., 1:] - z_vals[..., :-1]
dists = torch.cat([dists, # (ray #, sample #)
torch.Tensor([1e10]).expand(dists[..., :1].shape)],
-1)
dists = dists * torch.norm(rays_d[..., None, :], dim=-1)
rgb = torch.sigmoid(raw[..., :3]) # (ray #, sample #, 3)
noise = 0.
if raw_noise_std > 0.:
noise = torch.randn(raw[..., 3].shape) * raw_noise_std
# overwrite randomly sampled data
if pytest:
np.random.seed(0)
noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std
noise = torch.Tensor(noise)
alpha = raw2alpha(raw[..., 3] + noise, dists) # (ray #, sample #)
#weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)
weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)),
1. - alpha + 1e-10], -1),
-1)[:, :-1]
rgb_map = torch.sum(weights[..., None] * rgb, -2) # (ray #, 3)
depth_map = torch.sum(weights * z_vals, -1)
disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map),
depth_map / torch.sum(weights, -1))
acc_map = torch.sum(weights, -1)
if white_bkgd:
rgb_map = rgb_map + (1. - acc_map[..., None])
return rgb_map, disp_map, acc_map, weights, depth_map
dists calculates the differences between adjacent sample depths, $\delta_i = t_{i+1} - t_i$.
Purpose of appending a large value
A huge value ($10^{10}$) is appended as the last column of dists to (a) maintain the shape of dists as (ray #, sample #), and (b) force the last entry of the "opacity" alpha to be 1 so that classical alpha compositing (an opaque background) holds.
The sigmoid forces the RGB values rgb to lie in the range $(0, 1)$, that is,
$$\mathbf{c}_i = \operatorname{sigmoid}\big(\texttt{raw}[\ldots, :3]\big).$$
Random noise is injected into the volume density before it is rectified by ReLU and exponentiated. Let $\sigma_i$ denote the rectified "opacity" of the $i$th sample along a ray; then the entries of alpha used for alpha compositing are
$$\alpha_i = 1 - e^{-\sigma_i\delta_i},$$
computed elementwise over the whole batch (a Hadamard product of the density and distance tensors inside the exponent). torch.cumprod(…) calculates the cumulative transmittance
$$T_i = \prod_{j=1}^{i-1}\big(1 - \alpha_j + 10^{-10}\big):$$
a column of ones is prepended and the last column of the cumulative product is discarded, making the product "exclusive" (it excludes $\alpha_i$ itself) and matching the shape of alpha. Rewriting the weights (for the points' colors) as $w_i = T_i\,\alpha_i$, weights is
$$w_i = \alpha_i\prod_{j=1}^{i-1}\big(1 - \alpha_j + 10^{-10}\big).$$
Recall that the radiance is a weighted sum of the colors of the samples along a ray. This corresponds to the rgb_map line, whose output is
$$\texttt{rgb\_map} = \hat{C}(\mathbf{r}) = \sum_{i=1}^{N} w_i\,\mathbf{c}_i.$$
rgb_map and weights, along with the other values, are returned to render_rays(…).
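A quick sanity check (a standalone sketch, not repository code) that the prepend-ones-and-drop-last cumprod trick above indeed yields the exclusive product $T_i = \prod_{j<i}(1 - \alpha_j)$:

```python
import torch

alpha = torch.tensor([[0.1, 0.5, 0.9]])
ones = torch.ones((alpha.shape[0], 1))
T = torch.cumprod(torch.cat([ones, 1. - alpha + 1e-10], -1), -1)[:, :-1]
print(T)                    # tensor([[1.0000, 0.9000, 0.4500]])
print((T * alpha).sum(-1))  # accumulated opacity of the ray, always <= 1
```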
What else is returned?
Content on the way. Stay tuned!
sample_pdf(…) in run_nerf_helpers.py performs hierarchical sampling via inverse transform (Monte Carlo) sampling. It is invoked by
…
z_samples = sample_pdf(z_vals_mid,
weights[..., 1:-1],
N_importance,
det=(perturb==0.), # FALSE by default
pytest=pytest)
…
in render_rays(…), where z_vals_mid is a tensor of the midpoints of the coarse sample depths. Note that the leading and trailing columns of weights are excluded from the input so that there is exactly one weight per bin formed by adjacent midpoints.
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
# get PDF
weights = weights + 1e-5 # prevent NaN
pdf = weights / torch.sum(weights, -1, keepdim=True)
cdf = torch.cumsum(pdf, -1)
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (ray #, bin #)
# Here, `N_samples` refers to `N_importance`.
if det:
u = torch.linspace(0., 1., steps=N_samples)
u = u.expand(list(cdf.shape[:-1]) + [N_samples])
else:
u = torch.rand(list(cdf.shape[ :-1]) + [N_samples])
# if pytest, overwrite u with NumPy fixed random numbers
if pytest:
np.random.seed(0)
new_shape = list(cdf.shape[:-1]) + [N_samples]
if det:
u = np.linspace(0., 1., N_samples)
u = np.broadcast_to(u, new_shape)
else:
u = np.random.rand(*new_shape)
u = torch.Tensor(u)
# invert CDF
u = u.contiguous()
inds = torch.searchsorted(cdf, u, right=True)
below = torch.max(torch.zeros_like(inds-1), inds-1)
above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
inds_g = torch.stack([below, above], -1) # (ray #, sample #, 2)
#cdf_g = tf.gather(cdf , inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
#bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
cdf_g = torch.gather( cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
denom = cdf_g[..., 1] - cdf_g[..., 0]
denom = torch.where(denom<1e-5, torch.ones_like(denom),
denom)
t = (u - cdf_g[..., 0]) / denom
samples = bins_g[..., 0] + \
(bins_g[..., 1] - bins_g[..., 0]) * t
return samples # (ray #, sample #), unsorted along each ray
The pdf line defines the probability that a ray is stopped by a particle within each coarse bin. This corresponds to the area under the histogram shown above (first ray only).
The following lines accumulate the area into the CDF.
What follows is key to Monte Carlo sampling. A batch of random cumulative probabilities — "seeds" u — is generated. The figure above visualizes the operations on u[0, :], i.e., the seeds of the first ray. They fall into the "bins" through comparison against the cumulative probabilities at the bin boundaries: torch.searchsorted(…) returns the positions (indices) inds of those random "seeds".
Quicker indexing
torch.Tensor.contiguous() returns a tensor with identical data but contiguous in memory. It is called before torch.searchsorted(…) for performance reasons.
The lower bounds of the "bins" are collected as below, and the upper bounds are gathered as above; inds_g combines below and above. torch.gather(…) then looks up the CDF values and bin edges at those indices, which determines how the points along each ray are distributed — effectively, the number of "seeds" in each "bin".
Finally, the fine samples are found through similar triangles (linear interpolation), whose concept is illustrated above. Index 0 corresponds to below, and index 1 attaches to above. Given the cumulative probabilities u, there holds
$$\frac{\texttt{u} - \texttt{cdf\_g}[\ldots, 0]}{\texttt{cdf\_g}[\ldots, 1] - \texttt{cdf\_g}[\ldots, 0]} = \frac{\texttt{samples} - \texttt{bins\_g}[\ldots, 0]}{\texttt{bins\_g}[\ldots, 1] - \texttt{bins\_g}[\ldots, 0]}.$$
The code defines the denominator denom = cdf_g[..., 1] - cdf_g[..., 0] and the ratio t = (u - cdf_g[..., 0]) / denom; then
$$\texttt{samples} = \texttt{bins\_g}[\ldots, 0] + \big(\texttt{bins\_g}[\ldots, 1] - \texttt{bins\_g}[\ldots, 0]\big)\,\texttt{t}.$$
Hence, the last lines determine the (unsorted) output. The fine samples are expressed as offsets from (the midpoints of) the coarse samples, bins_g[..., 0].
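A compact sketch of the same inverse-CDF idea on a toy piecewise-constant PDF (names and values are illustrative, not from the repository):

```python
import torch

bins = torch.tensor([[0.0, 1.0, 2.0, 3.0]])   # bin edges along one ray
weights = torch.tensor([[0.1, 0.6, 0.3]])     # one (unnormalized) weight per bin

pdf = weights / weights.sum(-1, keepdim=True)
cdf = torch.cat([torch.zeros_like(pdf[..., :1]), torch.cumsum(pdf, -1)], -1)  # (1, 4)

u = torch.rand(1, 5)                           # 5 random "seeds" for this ray
inds = torch.searchsorted(cdf, u, right=True)  # which bin each seed falls into
below = (inds - 1).clamp(min=0)
above = inds.clamp(max=cdf.shape[-1] - 1)

# linear interpolation within the chosen bin (inverse transform sampling)
cdf_lo, cdf_hi = cdf.gather(-1, below), cdf.gather(-1, above)
bin_lo, bin_hi = bins.gather(-1, below), bins.gather(-1, above)
t = (u - cdf_lo) / (cdf_hi - cdf_lo).clamp(min=1e-5)
samples = bin_lo + (bin_hi - bin_lo) * t
print(samples)                                 # more samples land in the heavy middle bin
```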
Optimization
…
for i in trange(start, N_iters):
…
optimizer.zero_grad()
img_loss = img2mse(rgb, target_s)
trans = extras['raw'][...,-1]
loss = img_loss
psnr = mse2psnr(img_loss)
if 'rgb0' in extras:
img_loss0 = img2mse(extras['rgb0'], target_s)
loss = loss + img_loss0
psnr0 = mse2psnr(img_loss0)
loss.backward()
optimizer.step()
# NOTE: IMPORTANT!
### update learning rate ###
decay_rate = 0.1
decay_steps = args.lrate_decay * 1000
new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
for param_group in optimizer.param_groups:
param_group['lr'] = new_lrate
################################
…
The rendered radiance rgb is compared against the ground truth target_s to obtain the MSE loss. The total loss also includes that of the coarse network (the rgb0 entry in extras). The coarse and fine networks are jointly optimized by the same backward pass and optimizer step. Eventually, the learning rate decays exponentially with the global step.
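For reference, minimal sketches of the MSE/PSNR helpers and the decay schedule used above, written from the formulas in this snippet (the actual helper definitions live in run_nerf_helpers.py and may differ slightly):

```python
import torch

img2mse = lambda x, y: torch.mean((x - y) ** 2)   # photometric loss
mse2psnr = lambda x: -10. * torch.log10(x)        # PSNR in dB

# exponential decay: lrate * 0.1 ** (step / (lrate_decay * 1000))
def decayed_lr(lrate, lrate_decay, global_step):
    return lrate * (0.1 ** (global_step / (lrate_decay * 1000)))

print(mse2psnr(img2mse(torch.rand(8, 3), torch.rand(8, 3))))
print(decayed_lr(5e-4, 250, 100000))  # halfway through a 250k-step decay
```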
Summary
This post derives the volume rendering integral and its numerical quadrature. Also explained is its connection with classical alpha compositing. The second part elaborates on the implementation of the rendering pipeline. Illustrations are included to assist in understanding procedures such as ray generation and Monte Carlo sampling. Most importantly, the article specifies the physical meaning of each variable and gives the mathematical operation behind each statement. To sum up, the blog functions as a guide for an in-depth comprehension of NeRF.
References
Chapter 2.1 in Computer Vision: Algorithms and Applications
Fundamentals of Computer Graphics
NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis
NeRF PyTorch implementation by Yen-Chen Lin
Neural Radiance Field's Volume Rendering: Formula Analysis (in Chinese)
Optical Models for Direct Volume Rendering
Part 1 of the SIGGRAPH 2021 course on Advances in Neural Rendering
An in-depth interpretation of the yenchenlin/nerf-pytorch project (in Chinese)
Please use the following BibTeX to cite this post:
@misc{yyu2022nerfrendering,
author = {Yu, Yue},
title = {NeRF: A Volume Rendering Perspective},
year = {2022},
howpublished = {\url{https://yconquesty.github.io/blog/ml/nerf/nerf_rendering.html}}
}
Appendix
Content on the way. Stay tuned!
Errata
Time | Modification |
---|---|
Aug 31 2022 | Initial release |
Nov 24 2022 | Rectify reference list |
Dec 2 2022 | Add BibTeX for citation |
Apr 20 2023 | Elaborate on static compute graph |