NetKet Model Issue With Advanced Driver: A Fix Guide

by SLV Team

Introduction

Hey everyone! I'm facing a bit of a puzzle integrating a NetKet model with an advanced driver, and I thought I'd share my journey and hopefully get some insights from you all. I've got a NetKet model that's working smoothly with the standard NetKet drivers, but when I try to use it with an advanced driver, I'm hitting a snag. Specifically, I'm encountering a TypeError: iteration over a 0-d array. Let's dive into the details and see if we can figure this out together!

The Setup: My NetKet Model

First, let's take a look at the model I'm working with. It's based on one of the NetKet examples and centers on a LogNeuralBackflowSpinful neural network designed for fermionic systems. Here's the code that defines it:

import netket as nk
import jax
import jax.numpy as jnp
from flax import nnx
import optax
import advanced_drivers as advd

parameter_dtype = jnp.complex64  # we need complex numbers for fermions

pbc = [False, True]
n_hidden = 64
n_steps = 200
n_samples = 32
machine_pow = 1.0

chunk_size = 512  # to avoid OOM
tiny_init_std = 5e-3  # was 1e-3 in previous tests
lr = 0.01
exp_name = "HubbardBackflowRect_PBC"


L_full = 8
N_up = N_down = 14


hilb8 = nk.hilbert.SpinOrbitalFermions(
    n_orbitals=4*L_full, s=1/2, n_fermions_per_spin=(N_up, N_down)
)

hi = hilb8

t = 1
U = 8


def c(site, sz):
    return nk.operator.fermion.destroy(hi, site, sz=sz)


def cdag(site, sz):
    return nk.operator.fermion.create(hi, site, sz=sz)


def nc(site, sz):
    return nk.operator.fermion.number(hi, site, sz=sz)


g = nk.graph.Grid(extent=[L_full, 4], pbc=pbc)
g8 = g
up = +1
down = -1
ham = 0.0

for sz in (up, down):
    for u, v in g.edges():
        # Hopping term on each edge, plus its Hermitian conjugate
        ham += -t * cdag(u, sz) * c(v, sz) - t * cdag(v, sz) * c(u, sz)
for u in g.nodes():
    # On-site Hubbard repulsion
    ham += U * nc(u, up) * nc(u, down)


H8 = ham  # nkx.operator.ParticleNumberAndSpinConservingFermioperator2nd.from_fermionoperator2nd(ham)

def _logdet_cmplx(A, eps=1e-6):
    # Stable logdet
    eps_eye = eps * jnp.eye(A.shape[0], dtype=A.dtype)
    sign, logabsdet = jnp.linalg.slogdet(A + eps_eye)
    # jax.debug.print("slogdet sign: {s}, logabsdet: {l}", s=sign, l=logabsdet)
    if parameter_dtype == jnp.float32:
        return logabsdet  # .astype(complex) + jnp.log(sign.astype(complex))
    else:
        return logabsdet.astype(complex) + jnp.log(sign.astype(complex))


class LogNeuralBackflowSpinful(nnx.Module):
    hilbert: "nk.hilbert.SpinOrbitalFermions"

    def __init__(
        self,
        hilbert,
        hidden_units: int,
        kernel_init=nnx.initializers.lecun_normal(),
        param_dtype=parameter_dtype,
        *,
        rngs: nnx.Rngs,
    ):
        self.hilbert = hilbert
        L = hilbert.n_orbitals
        Nup, Ndown = hilbert.n_fermions_per_spin

        key_up, key_down, key_bf1, key_bf2 = jax.random.split(rngs.params(), 4)

        orthogonal_init = nnx.initializers.orthogonal()

        self.M_up = nnx.Param(orthogonal_init(key_up, (L, Nup), param_dtype))
        self.M_down = nnx.Param(orthogonal_init(key_down, (L, Ndown), param_dtype))

        tiny_init = nnx.initializers.normal(stddev=tiny_init_std)

        # Backflow network: input = spinful config (2*L,), output = a
        # correction for every orbital of both spin species. Each Linear
        # layer gets its own RNG key so the initializations are independent.
        self.backflow = nnx.Sequential(
            nnx.Linear(
                in_features=2 * L,
                out_features=hidden_units,
                param_dtype=param_dtype,
                rngs=nnx.Rngs(key_bf1),
                kernel_init=tiny_init,
            ),
            nnx.tanh,
            nnx.Linear(
                in_features=hidden_units,
                out_features=L * (Nup + Ndown),
                param_dtype=param_dtype,
                rngs=nnx.Rngs(key_bf2),
                kernel_init=tiny_init,
                use_bias=False,
            ),
            lambda x: x.reshape(
                x.shape[:-1] + (L, Nup + Ndown)
            ),  # (..., L, Nup+Ndown)
        )

    def __call__(self, n: jax.Array) -> jax.Array:
        def log_sd(ncfg):
            L = self.hilbert.n_orbitals
            Nup, Ndown = self.hilbert.n_fermions_per_spin

            n_up, n_down = ncfg[:L], ncfg[L:]
            R_up = n_up.nonzero(size=Nup)[0]
            R_down = n_down.nonzero(size=Ndown)[0]

            # Backflow correction
            F = self.backflow(ncfg)  # shape (L, Nup+Ndown)
            F_up, F_down = F[:, :Nup], F[:, Nup:]

            M_up = self.M_up + F_up
            M_down = self.M_down + F_down

            A_up = M_up[R_up]
            A_down = M_down[R_down]

            ret = _logdet_cmplx(A_up) + _logdet_cmplx(A_down)
            return ret
        
        if n.ndim == 1:
            return log_sd(n)
        return jax.vmap(log_sd)(n)

  
rngs = nnx.Rngs(jax.random.PRNGKey(0))
ma8 = LogNeuralBackflowSpinful(hilb8, rngs=rngs, hidden_units=n_hidden)
sa8 = nk.sampler.MetropolisFermionHop(hilb8, graph=g8,
                                      n_chains=64, sweep_size=256, machine_pow=machine_pow)
vs8 = nk.vqs.MCState(sa8, ma8, n_samples=n_samples)
vs8.chunk_size = chunk_size  # to avoid OOM

This code defines a LogNeuralBackflowSpinful module using nnx, Flax's newer module API. It implements a neural-network ansatz for a fermionic system: a pair of Slater determinants whose orbitals get a configuration-dependent backflow correction. The model returns the logarithm of the wave-function amplitude, which is what variational Monte Carlo (VMC) drivers expect.
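
Before bringing the driver into the picture, here's a quick sanity check I find useful (my own addition, not part of the original script): calling the model directly on configurations drawn from the Hilbert space, once with a single sample and once with a batch, to confirm that __call__ handles both 1-D and 2-D inputs.

# Sanity check (assumes the definitions above): single vs. batched inputs.
key = jax.random.PRNGKey(42)
single = hi.random_state(key)          # shape (2 * n_orbitals,)
batch = hi.random_state(key, size=8)   # shape (8, 2 * n_orbitals)

print(ma8(single).shape)  # expected: (), a single scalar log-amplitude
print(ma8(batch).shape)   # expected: (8,)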

The Problem: TypeError with Advanced Driver

The issue arises when I try to use this model with the VMC_NG driver from the advanced_drivers package (imported above as advd). Here's the code snippet that triggers the error:

n_iter = 100
diag_shift = optax.linear_schedule(
    init_value=0.01,
    transition_steps=200,
    end_value=1e-4,
)
modulus_distribution = advd.driver.overdispersed_distribution(alpha=1.0)
opt = optax.sgd(learning_rate=lr)
vstate = vs8
vstate.chunk_size = 1024
driver = advd.driver.VMC_NG(
    hamiltonian=H8,
    optimizer=opt,
    # sampling_distribution=modulus_distribution,  # the new sampling distribution (disabled while debugging)
    variational_state=vstate,
    diag_shift=diag_shift,
    # use_ntk=False,
    # on_the_fly=False,
)

# out_psi_mod = nk.logging.RuntimeLog()
# fs_state_err = nk.vqs.FullSumState(hilbert=vstate.hilbert,
#                                    model=vstate.model,
#                                    seed=0)
driver.run(n_iter=n_iter)  # , out=out_psi_mod, callback=(InvalidLossStopping(), save_exact_err(fs_state_err, E_gs, save_every=10))

When I run this code, I get the following error:

File ".../.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 3708

-> 3708     raise TypeError("iteration over a 0-d array")  # same as numpy error
   3709   else:
   3710     n = int(tracer.shape[0])

TypeError: iteration over a 0-d array

This TypeError means that something tried to iterate over, or unpack, a 0-dimensional (scalar) array where an iterable was expected. In JAX this usually surfaces inside a compiled function when a value that was supposed to carry a leading dimension arrives as a scalar.
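
To make that failure mode concrete, here is a tiny standalone snippet (my own illustration, nothing to do with NetKet) that raises the identical error from the very line quoted above:

import jax
import jax.numpy as jnp

def f(s):
    a, b = s  # unpacking iterates the traced 0-d array: lax.py:3708
    return a + b

jax.jit(f)(jnp.array(3.0))  # TypeError: iteration over a 0-d array

So somewhere inside the advanced driver, a value the code expects to be a vector (or tuple) is arriving as a scalar.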

The Good News: It Works with Standard NetKet Driver

The interesting part is that the same model works perfectly fine with the standard NetKet driver. Here’s the code that runs without issues:

print("hidden units:", n_hidden)
opt = nk.optimizer.Sgd(learning_rate=lr)
sr = nk.optimizer.SR(diag_shift=0.1, holomorphic=True)
gs8 = nk.driver.VMC(H8, opt, variational_state=vs8, preconditioner=sr)

gs8.run(n_steps, out=exp_name)

This tells us that the core model definition and the variational Monte Carlo setup are likely correct. The problem seems to be specific to how the advanced driver handles the model or its inputs.

Diagnosing the TypeError: My Investigation

So, what’s causing this TypeError when using the advanced driver? Let's break down the potential culprits and how I'm trying to tackle them.

Understanding the Error Context

The error is raised from jax/_src/lax/lax.py, in the code that implements iteration over traced arrays. Since the standard NetKet driver works, the problem likely lies in how advanced_drivers interacts with JAX transformations (like jax.vmap or jax.jit) or in how it handles array shapes.
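
One practical tip before digging deeper: JAX filters tracebacks by default, which can hide the advanced_drivers frames between driver.run() and the raising line in lax.py. Turning the filtering off shows the full call chain:

import jax

# Show every frame in the traceback, including library internals.
jax.config.update("jax_traceback_filtering", "off")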

Potential Causes and My Attempts to Resolve

  1. Shape Mismatches: The most common cause of this error is an unexpected shape in an array being passed to a JAX function.

    • My Approach: I've been carefully reviewing the input shapes going into the VMC_NG driver, especially within the loss function and gradient computation, making sure that the arrays passed to JAX functions have the expected dimensions (see the eval_shape sketch after this list).
  2. Incorrect Use of JAX Transformations: Functions like jax.vmap and jax.jit can sometimes lead to unexpected behavior if not used correctly.

    • My Approach: I'm examining the advanced_drivers code to see how these transformations are applied, particularly in the context of the variational state and the Hamiltonian.
  3. Incompatible Data Types: JAX is very strict about data types. A mismatch between expected and actual data types can cause issues.

    • My Approach: I’m ensuring that the parameter_dtype is consistent throughout the model and the driver. I’m using jnp.complex64 for complex numbers, which is necessary for fermionic systems.
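
For the shape checks in point 1, jax.eval_shape is handy: it traces the model abstractly, so you see output shapes and dtypes without paying for a real evaluation. A sketch of what I'm running (it assumes vstate, ma8, and hi from the scripts above):

# Check what the model produces for a full batch of samples, abstractly.
samples = vstate.samples.reshape(-1, hi.size)  # flatten (chains, steps, size)
out = jax.eval_shape(ma8, samples)
print(samples.shape, "->", out.shape, out.dtype)
# Expect (n_total_samples, 64) -> (n_total_samples,) with a complex dtype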

Diving Deep into the VMC_NG Driver

To really get to the bottom of this, I need to understand what the VMC_NG driver is doing differently from the standard nk.driver.VMC. Here are some key areas I'm focusing on:

  • Sampling Distribution: The advanced driver allows for custom sampling distributions. I've tried using overdispersed_distribution, but the error persists even without specifying a custom distribution. This suggests the issue isn't solely with the sampling distribution itself.
  • Optimizer and Learning Rate Schedule: The VMC_NG driver takes optax optimizers, which differ from the standard NetKet ones. I'm using optax.sgd with a linear learning-rate schedule (optax.linear_schedule), and I also pass a schedule as diag_shift. I need to make sure these schedules behave as expected and don't produce values of an unexpected shape (see the check after this list).
  • Gradient Computation: The way gradients are computed might be different in the advanced driver. I’ll be looking at how the loss function is defined and how gradients are calculated using JAX.
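
On the schedule point: optax schedules are plain callables from step count to value, so they're easy to probe by hand. Note that each call returns a 0-d array rather than a Python float; I haven't confirmed this is related to the error, but given the message it seems worth double-checking that VMC_NG accepts a schedule for diag_shift the way NetKet's SR preconditioner does.

# Probe the diag_shift schedule at a few step counts.
import jax.numpy as jnp

for step in (0, 50, 100, 200, 300):
    val = diag_shift(step)
    print(step, val, jnp.ndim(val))  # ndim is 0: a scalar, not an iterable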

Steps I’ve Taken So Far

  • Verified Input Shapes: I've added print statements and JAX debugging tools to check the shapes of arrays at various points in the code (a sketch follows this list). This has confirmed that the input to the model has the expected shape.
  • Simplified the Learning Rate Schedule: I’ve tried using a constant learning rate instead of a schedule to rule out any issues with the schedule itself.
  • Checked Data Types: I’ve made sure that the data types are consistent, especially for complex numbers.
  • Reviewed JAX Transformations: I’ve examined the use of jax.vmap and jax.jit in the advanced_drivers code to see if there are any potential issues.
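
For reference, this is roughly how I instrumented the forward pass (a sketch reusing ma8 and hi from above; shapes are static under jit/vmap, so an ordinary print shows them at trace time, while jax.debug.print is needed for runtime values):

def debug_apply(n):
    print("input shape:", n.shape)  # static info, printed when traced
    jax.debug.print("first entries: {v}", v=n[..., :4])  # works under jit/vmap
    return ma8(n)

out = debug_apply(hi.random_state(jax.random.PRNGKey(1), size=4))
print("output shape:", out.shape)  # expect (4,)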

Current Status and Next Steps

As of now, I'm still stuck on this TypeError. I've narrowed down the potential causes, but I haven't found the exact root of the problem. Here’s what I plan to do next:

  1. Minimal Reproducible Example: I'm going to create a minimal, self-contained example that reproduces the error (a skeleton follows this list). This will isolate the issue and make debugging easier.
  2. Step-by-Step Debugging: I’ll use JAX’s debugging tools (like jax.debug.print) to step through the code and inspect the values of variables at each step. This should help me pinpoint exactly where the error occurs.
  3. Consult the NetKet Community: I’ll reach out to the NetKet community for help. There might be others who have encountered similar issues, or someone might spot something I’ve missed.
  4. Examine advanced_drivers Code More Closely: I’ll continue to dive into the advanced_drivers code, paying close attention to how it handles gradients and optimization.
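
Here's the skeleton of the minimal repro I'm putting together (a sketch, under the assumption that the error survives shrinking the system; it reuses the LogNeuralBackflowSpinful class and the imports from the scripts above):

# Tiny 4-site chain, one fermion per spin, same model class as above.
hi_s = nk.hilbert.SpinOrbitalFermions(n_orbitals=4, s=1/2, n_fermions_per_spin=(1, 1))
g_s = nk.graph.Chain(length=4, pbc=True)

ham_s = 0.0
for sz in (+1, -1):  # same spin labels as in the full script
    for u, v in g_s.edges():
        ham_s += -(
            nk.operator.fermion.create(hi_s, u, sz=sz)
            * nk.operator.fermion.destroy(hi_s, v, sz=sz)
            + nk.operator.fermion.create(hi_s, v, sz=sz)
            * nk.operator.fermion.destroy(hi_s, u, sz=sz)
        )

ma_s = LogNeuralBackflowSpinful(hi_s, hidden_units=4, rngs=nnx.Rngs(jax.random.PRNGKey(0)))
sa_s = nk.sampler.MetropolisFermionHop(hi_s, graph=g_s, n_chains=4)
vs_s = nk.vqs.MCState(sa_s, ma_s, n_samples=16)

driver_s = advd.driver.VMC_NG(
    hamiltonian=ham_s,
    optimizer=optax.sgd(learning_rate=0.01),
    variational_state=vs_s,
    diag_shift=1e-3,  # constant shift, to also rule out the schedule
)
driver_s.run(n_iter=2)  # does the TypeError still appear?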

Call for Help: Let's Solve This Together!

I’m sharing this in the hopes that some of you might have insights or suggestions that could help me resolve this issue. Have you encountered a similar TypeError in JAX or NetKet? Do you have any tips for debugging JAX code, especially when dealing with advanced drivers or custom sampling distributions?

If you have any ideas or suggestions, please feel free to share them in the comments below. Let's work together to get this NetKet model running smoothly with the advanced driver!

Thanks for reading, and I look forward to hearing your thoughts! Let's crack this nut together! 🚀