Skip to content

Commit 3ec4ae6

Browse files
sunqmQiming Sun
and
Qiming Sun
authored
Warp divergent optimization (pyscf#324)
* A different 4c2e algorithm
* Modify rys_roots structure
* New code generator works
* Update rys_jk unrolling
* Update rys_roots in various modules
* Missing file
* Update unrolled_rys
* Bugfixes for unrolled_rys
* Fix rys_contract_jk_ip1
* Update unrolled_ejk_ip1
* Fix bugs
* Optimize unrolled_ejk_ip1
* Optimize shm footprint in rys_contract_jk_ip1
* Improve rys_contract_jk_ip1 and unrolled_ejk_ip1
* Optimize rys_contract_jk_ip2
* Fix unrolled_ejk_ip2
* reduce memory footprint for ip2_type3
* Update rys_contract_jk_ip1
* Fixes
* Update unrolled_rys function signature
* Fix unrolled_rys_ip1
* update unrolled_rys_ip1
* Update create_tasks
* Change rys_roots path
* Update j engine
* Update pbc/rys_roots_dat.cu
* Fix overflow for 48KB shared memory
* Fix bug in uhf gradients kernel
* Fix unrolled ip1 and ip2 code
* Apply rys_roots_rs function for rys_contract_j
* Improve logger.init_timer
* Remove unused files
* Adjust DD_CACHE_MAX size dynamically
---------
Co-authored-by: Qiming Sun <[email protected]>
1 parent 45483ee commit 3ec4ae6

36 files changed

+158072
-124530
lines changed

gpu4pyscf/grad/rhf.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,26 @@
3535
LMAX, QUEUE_DEPTH, SHM_SIZE, THREADS, libvhf_rys, _VHFOpt, init_constant,
3636
_make_tril_tile_mappings, _nearest_power2)
3737

38-
libvhf_rys.RYS_per_atom_jk_ip1.restype = ctypes.c_int
39-
4038
__all__ = [
4139
'SCF_GradScanner',
4240
'Gradients',
4341
'Grad'
4442
]
4543

44+
libvhf_rys.RYS_per_atom_jk_ip1.restype = ctypes.c_int
45+
# The max. size of nf*nsq_per_block for each block.
46+
# If shared memory is 48KB, this is enough to cache up to g-type functions,
47+
# corresponding to 15^4 with nsq_per_block=2. All other cases require smaller
48+
# cache for the product of density matrices. Although nsq_per_block would be
49+
# larger, the overall cache requirement is smaller. The following code gives
50+
# the size estimate for each angular momentum pattern (see also
51+
# _ejk_quartets_scheme)
52+
# for li, lj, lk, ll in itertools.product(*[range(LMAX+1)]*4):
53+
# nf = (li+1)*(li+2) * (lj+1)*(lj+2) * (lk+1)*(lk+2) * (ll+1)*(ll+2) // 16
54+
# g_size = (li+2)*(lj+1)*(lk+2)*(ll+1)
55+
# dd_cache_size = nf * min(THREADS, _nearest_power2(SHM_SIZE//(g_size*3*8)))
56+
DD_CACHE_MAX = 101250 * (SHM_SIZE//48000)
57+
4658
def _ejk_ip1_task(mol, dms, vhfopt, task_list, j_factor=1.0, k_factor=1.0,
4759
device_id=0, verbose=0):
4860
n_dm = dms.shape[0]
@@ -77,6 +89,7 @@ def _ejk_ip1_task(mol, dms, vhfopt, task_list, j_factor=1.0, k_factor=1.0,
7789
log_cutoff-log_max_dm)
7890
workers = gpu_specs['multiProcessorCount']
7991
pool = cp.empty((workers, QUEUE_DEPTH*4), dtype=np.uint16)
92+
dd_pool = cp.empty((workers, DD_CACHE_MAX), dtype=np.float64)
8093
info = cp.empty(2, dtype=np.uint32)
8194
t1 = log.timer_debug1(f'q_cond and dm_cond on Device {device_id}', *cput0)
8295

@@ -104,6 +117,7 @@ def _ejk_ip1_task(mol, dms, vhfopt, task_list, j_factor=1.0, k_factor=1.0,
104117
ctypes.cast(dm_cond.data.ptr, ctypes.c_void_p),
105118
ctypes.c_float(log_cutoff),
106119
ctypes.cast(pool.data.ptr, ctypes.c_void_p),
120+
ctypes.cast(dd_pool.data.ptr, ctypes.c_void_p),
107121
ctypes.cast(info.data.ptr, ctypes.c_void_p),
108122
ctypes.c_int(workers),
109123
mol._atm.ctypes, ctypes.c_int(mol.natm),
@@ -193,11 +207,11 @@ def _ejk_quartets_scheme(mol, l_ctr_pattern, shm_size=SHM_SIZE):
193207
ls = l_ctr_pattern[:,0]
194208
li, lj, lk, ll = ls
195209
order = li + lj + lk + ll
196-
g_size = (li+2)*(lj+2)*(lk+2)*(ll+2)
210+
g_size = (li+2)*(lj+1)*(lk+2)*(ll+1)
197211
nps = l_ctr_pattern[:,1]
198212
ij_prims = nps[0] * nps[1]
199213
nroots = (order + 1) // 2 + 1
200-
unit = nroots*2 + g_size*3 + ij_prims*4
214+
unit = nroots*2 + g_size*3 + ij_prims + 9
201215
if mol.omega < 0: # SR
202216
unit += nroots * 2
203217
counts = shm_size // (unit*8)

gpu4pyscf/grad/tests/test_rhf_grad.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import pyscf
1615
import numpy as np
16+
import cupy as cp
1717
import unittest
1818
import pytest
19+
import pyscf
20+
from pyscf import lib
1921
from pyscf import scf as cpu_scf
2022
from gpu4pyscf import scf as gpu_scf
23+
from pyscf.grad import rhf as rhf_grad_cpu
24+
from gpu4pyscf.grad import rhf as rhf_grad_gpu
2125
from packaging import version
2226

2327
atom = '''
@@ -85,6 +89,32 @@ def test_to_cpu(self):
8589
g_cpu = cpu_gradient.kernel()
8690
assert np.linalg.norm(g_gpu - g_cpu) < 1e-5
8791

92+
def test_jk_energy_per_atom(self):
93+
mol = pyscf.M(
94+
atom = '''
95+
O 0.000 -0. 0.1174
96+
H -0.757 4. -0.4696
97+
H 0.757 4. -0.4696
98+
C 3. 1. 0.
99+
''',
100+
basis='def2-tzvp',
101+
unit='B',)
102+
np.random.seed(9)
103+
nao = mol.nao
104+
dm = np.random.rand(nao, nao) - .5
105+
dm = cp.asarray(dm.dot(dm.T))
106+
ejk = rhf_grad_gpu._jk_energy_per_atom(mol, dm).get()
107+
self.assertAlmostEqual(ejk.sum(), 0, 9)
108+
self.assertAlmostEqual(lib.fp(ejk), 2710.490337642, 9)
109+
110+
dm = dm.get()
111+
vj, vk = rhf_grad_cpu.get_jk(mol, dm)
112+
veff = vj - vk * .5
113+
ref = np.empty_like(ejk)
114+
for n, (i0, i1) in enumerate(mol.aoslice_by_atom()[:,2:]):
115+
ref[n] = np.einsum('xpq,pq->x', veff[:,i0:i1], dm[i0:i1])
116+
self.assertAlmostEqual(abs(ejk - ref).max(), 0, 9)
117+
88118
if __name__ == "__main__":
89119
print("Full Tests for RHF Gradient")
90120
unittest.main()

gpu4pyscf/grad/tests/test_uhf_grad.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,11 @@
3131

3232
def setUpModule():
3333
global mol_sph, mol_cart
34-
mol_sph = pyscf.M(atom=atom, basis=bas0, max_memory=32000)
35-
mol_sph.output = '/dev/null'
36-
mol_sph.build()
37-
mol_sph.verbose = 1
34+
mol_sph = pyscf.M(atom=atom, basis=bas0, max_memory=32000,
35+
output='/dev/null', verbose=1)
3836

39-
mol_cart = pyscf.M(atom=atom, basis=bas0, max_memory=32000, cart=1)
40-
mol_cart.output = '/dev/null'
41-
mol_cart.build()
42-
mol_cart.verbose = 1
37+
mol_cart = pyscf.M(atom=atom, basis=bas0, max_memory=32000, cart=1, spin=2,
38+
output='/dev/null', verbose=1)
4339

4440
def tearDownModule():
4541
global mol_sph, mol_cart
@@ -64,11 +60,11 @@ def _check_grad(mol, tol=1e-6, disp=None):
6460
class KnownValues(unittest.TestCase):
6561
def test_grad_uhf(self):
6662
print('---- testing UHF -------')
67-
_check_grad(mol_sph, tol=1e-6)
63+
_check_grad(mol_sph, tol=1e-10)
6864

6965
def test_grad_cart(self):
7066
print('---- testing UHF Cart -------')
71-
_check_grad(mol_cart, tol=1e-6)
67+
_check_grad(mol_cart, tol=1e-10)
7268

7369
@pytest.mark.skipif(pyscf_25, reason='requires pyscf 2.6 or higher')
7470
def test_grad_d3bj(self):

gpu4pyscf/grad/tests/test_uks_grad.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def setUpModule():
3333
mol_sph = pyscf.M(atom=atom, basis=bas0, max_memory=32000,
3434
output='/dev/null', verbose=1)
3535

36-
mol_cart = pyscf.M(atom=atom, basis=bas0, max_memory=32000, cart=1,
36+
mol_cart = pyscf.M(atom=atom, basis=bas0, max_memory=32000, cart=1, spin=2,
3737
output='/dev/null', verbose=1)
3838

3939
def tearDownModule():
@@ -65,39 +65,39 @@ class KnownValues(unittest.TestCase):
6565

6666
def test_grad_with_grids_response(self):
6767
print("-----testing unrestricted DFT gradient with grids response----")
68-
_check_grad(mol_sph, grid_response=True)
68+
_check_grad(mol_sph, grid_response=True, tol=1e-10)
6969

7070
def test_grad_without_grids_response(self):
7171
print('-----testing unrestricted DFT gradient without grids response----')
72-
_check_grad(mol_sph, grid_response=False)
72+
_check_grad(mol_sph, grid_response=False, tol=1e-10)
7373

7474
def test_grad_lda(self):
7575
print("-----LDA testing-------")
76-
_check_grad(mol_sph, xc='LDA', disp=None)
76+
_check_grad(mol_sph, xc='LDA', disp=None, tol=1e-10)
7777

7878
def test_grad_gga(self):
7979
print('-----GGA testing-------')
80-
_check_grad(mol_sph, xc='PBE', disp=None)
80+
_check_grad(mol_sph, xc='PBE', disp=None, tol=1e-10)
8181

8282
def test_grad_hybrid(self):
8383
print('------hybrid GGA testing--------')
84-
_check_grad(mol_sph, xc='B3LYP', disp=None)
84+
_check_grad(mol_sph, xc='B3LYP', disp=None, tol=1e-10)
8585

8686
def test_grad_mgga(self):
8787
print('-------mGGA testing-------------')
88-
_check_grad(mol_sph, xc='tpss', disp=None)
88+
_check_grad(mol_sph, xc='tpss', disp=None, tol=1e-10)
8989

9090
def test_grad_rsh(self):
9191
print('--------RSH testing-------------')
92-
_check_grad(mol_sph, xc='wb97', disp=None)
92+
_check_grad(mol_sph, xc='wb97', disp=None, tol=1e-10)
9393

9494
def test_grad_nlc(self):
9595
print('--------nlc testing-------------')
9696
_check_grad(mol_sph, xc='HYB_MGGA_XC_WB97M_V', disp=None)
9797

9898
def test_grad_cart(self):
9999
print('------hybrid GGA Cart testing--------')
100-
_check_grad(mol_cart, xc='B3LYP', disp=None)
100+
_check_grad(mol_cart, xc='B3LYP', disp=None, tol=1e-10)
101101

102102
def test_grad_d3bj(self):
103103
print('------hybrid GGA with D3(BJ) testing--------')

gpu4pyscf/hessian/rhf.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646

4747
GB = 1024*1024*1024
4848
ALIGNED = 4
49+
DD_CACHE_MAX = rhf_grad.DD_CACHE_MAX
4950

5051
def hess_elec(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
5152
mo1=None, mo_e1=None, h1mo=None,
@@ -201,6 +202,7 @@ def _ejk_ip2_task(mol, dms, vhfopt, task_list, j_factor=1.0, k_factor=1.0,
201202
log_cutoff-log_max_dm)
202203
workers = gpu_specs['multiProcessorCount']
203204
pool = cp.empty((workers, QUEUE_DEPTH*4), dtype=np.uint16)
205+
dd_pool = cp.empty((workers, DD_CACHE_MAX), dtype=np.float64)
204206
info = cp.empty(2, dtype=np.uint32)
205207
t1 = log.timer_debug1(f'q_cond and dm_cond on Device {device_id}', *cput0)
206208

@@ -228,10 +230,13 @@ def _ejk_ip2_task(mol, dms, vhfopt, task_list, j_factor=1.0, k_factor=1.0,
228230
ctypes.cast(dm_cond.data.ptr, ctypes.c_void_p),
229231
ctypes.c_float(log_cutoff),
230232
ctypes.cast(pool.data.ptr, ctypes.c_void_p),
233+
ctypes.cast(dd_pool.data.ptr, ctypes.c_void_p),
231234
ctypes.cast(info.data.ptr, ctypes.c_void_p),
232235
ctypes.c_int(workers),
233236
mol._atm.ctypes, ctypes.c_int(mol.natm),
234237
mol._bas.ctypes, ctypes.c_int(mol.nbas), mol._env.ctypes)
238+
239+
scheme = _ip2_type3_quartets_scheme(mol, uniq_l_ctr[[i, j, k, l]])
235240
err2 = kern2(
236241
ctypes.cast(ejk.data.ptr, ctypes.c_void_p),
237242
ctypes.c_double(j_factor), ctypes.c_double(k_factor),
@@ -247,10 +252,12 @@ def _ejk_ip2_task(mol, dms, vhfopt, task_list, j_factor=1.0, k_factor=1.0,
247252
ctypes.cast(dm_cond.data.ptr, ctypes.c_void_p),
248253
ctypes.c_float(log_cutoff),
249254
ctypes.cast(pool.data.ptr, ctypes.c_void_p),
255+
ctypes.cast(dd_pool.data.ptr, ctypes.c_void_p),
250256
ctypes.cast(info.data.ptr, ctypes.c_void_p),
251257
ctypes.c_int(workers),
252258
mol._atm.ctypes, ctypes.c_int(mol.natm),
253259
mol._bas.ctypes, ctypes.c_int(mol.nbas), mol._env.ctypes)
260+
254261
if err1 != 0 or err2 != 0:
255262
raise RuntimeError(f'RYS_per_atom_jk_ip2 kernel for {llll} failed')
256263
if log.verbose >= logger.DEBUG1:
@@ -345,7 +352,23 @@ def _ip2_quartets_scheme(mol, l_ctr_pattern, shm_size=SHM_SIZE):
345352
nps = l_ctr_pattern[:,1]
346353
ij_prims = nps[0] * nps[1]
347354
nroots = (order + 2) // 2 + 1
348-
unit = nroots*2 + g_size*3 + ij_prims*4
355+
unit = nroots*2 + g_size*3 + ij_prims + 9
356+
if mol.omega < 0: # SR
357+
unit += nroots * 2
358+
counts = shm_size // (unit*8)
359+
n = min(THREADS, _nearest_power2(counts))
360+
gout_stride = THREADS // n
361+
return n, gout_stride
362+
363+
def _ip2_type3_quartets_scheme(mol, l_ctr_pattern, shm_size=SHM_SIZE):
364+
ls = l_ctr_pattern[:,0]
365+
li, lj, lk, ll = ls
366+
order = li + lj + lk + ll
367+
g_size = (li+2)*(lj+1)*(lk+2)*(ll+1)
368+
nps = l_ctr_pattern[:,1]
369+
ij_prims = nps[0] * nps[1]
370+
nroots = (order + 2) // 2 + 1
371+
unit = nroots*2 + g_size*3 + ij_prims + 9
349372
if mol.omega < 0: # SR
350373
unit += nroots * 2
351374
counts = shm_size // (unit*8)
@@ -590,9 +613,8 @@ def _ip1_quartets_scheme(mol, l_ctr_pattern, shm_size=SHM_SIZE):
590613
ij_prims = nps[0] * nps[1]
591614
nroots = (order + 1) // 2 + 1
592615

593-
unit = nroots*2 + g_size*3
594-
shm_size -= ij_prims*12 * 8
595-
counts = shm_size // (unit*8)
616+
unit = nroots*2 + g_size*3 + 6
617+
counts = (shm_size - ij_prims*6 * 8) // (unit*8)
596618
n = min(THREADS, _nearest_power2(counts))
597619
gout_stride = THREADS // n
598620
gout_width = 18

gpu4pyscf/hessian/tests/test_rhf_hessian.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_hessian_rhf(self):
5151

5252
def test_partial_hess_elec(self):
5353
mf = pyscf.scf.RHF(mol)
54-
mf.conv_tol = 1e-14
54+
mf.conv_tol = 1e-12
5555
mf.kernel()
5656
hobj = mf.Hessian()
5757
e1_cpu, ej_cpu, ek_cpu = rhf_cpu._partial_hess_ejk(hobj)
@@ -62,8 +62,8 @@ def test_partial_hess_elec(self):
6262
hobj = mf.Hessian()
6363
e1_gpu, e2_gpu = rhf_gpu._partial_hess_ejk(hobj)
6464

65-
assert abs(e1_cpu - e1_gpu.get()).max() < 1e-5
66-
assert abs(e2_cpu - e2_gpu.get()).max() < 1e-5
65+
assert abs(e1_cpu - e1_gpu.get()).max() < 1e-7
66+
assert abs(e2_cpu - e2_gpu.get()).max() < 1e-7
6767

6868
def test_ejk_ip2(self):
6969
mol = gto.M(
@@ -76,20 +76,21 @@ def test_ejk_ip2(self):
7676
basis='6-31g**', unit='B')
7777
np.random.seed(9)
7878
nao = mol.nao
79-
mo_coeff = np.random.rand(nao, nao)
79+
mo_coeff = np.random.rand(nao, nao) - .5
8080
dm = mo_coeff.dot(mo_coeff.T) * 2
8181
mo_occ = np.ones(nao) * 2
8282
mo_energy = np.random.rand(nao)
8383

8484
ejk = rhf_gpu._partial_ejk_ip2(mol, dm)
85+
assert abs(lib.fp(ejk.get()) - 1116.6336092900506) < 1e-8
8586
mf = mol.RHF()
8687
mf.mo_coeff = mo_coeff
8788
mf.mo_occ = mo_occ
8889
mf.mo_energy = mo_energy
8990
h = rhf_cpu.Hessian(mf)
9091
e1, refj, refk = rhf_cpu._partial_hess_ejk(h, mo_energy, mo_coeff, mo_occ)
9192
e2_ref = refj - refk
92-
assert abs(ejk.get() - e2_ref).max() < 1e-6
93+
assert abs(ejk.get() - e2_ref).max() < 1e-8
9394

9495
def test_get_jk(self):
9596
mol = gto.M(

gpu4pyscf/hessian/tests/test_uhf_hessian.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def test_hessian_uhf(self):
5353

5454
def test_partial_hess_elec(self):
5555
mf = pyscf.scf.UHF(mol)
56-
mf.conv_tol = 1e-14
56+
mf.conv_tol = 1e-12
5757
mf.kernel()
5858
hobj = mf.Hessian()
5959
e1_cpu, ej_cpu, ek_cpu = uhf_cpu._partial_hess_ejk(hobj)
@@ -64,8 +64,8 @@ def test_partial_hess_elec(self):
6464
hobj = mf.Hessian()
6565
e1_gpu, e2_gpu = uhf_gpu._partial_hess_ejk(hobj)
6666

67-
assert numpy.linalg.norm(e1_cpu - e1_gpu.get()) < 1e-5
68-
assert numpy.linalg.norm(e2_cpu - e2_gpu.get()) < 1e-5
67+
assert numpy.linalg.norm(e1_cpu - e1_gpu.get()) < 1e-7
68+
assert numpy.linalg.norm(e2_cpu - e2_gpu.get()) < 1e-7
6969

7070
def test_hessian_uhf_D3(self):
7171
print('----- testing UHF with D3BJ ------')

gpu4pyscf/lib/gint/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ add_library(gint SHARED
3535
nr_fill_ao_int3c2e_ip1ip2.cu
3636
nr_fill_ao_int3c2e_ipvip1.cu
3737
j_engine_matrix_reorder.c
38+
rys_roots_dat.cu
3839
)
3940

4041
#option(BUILD_SHARED_LIBS "build shared libraries" 1)

gpu4pyscf/lib/gint/rys_roots_dat.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
#include "gvhf-rys/rys_roots_dat.cu"

gpu4pyscf/lib/gvhf-rys/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --ptxas-options=-v")# -maxrregcount=128")
22

33
add_library(gvhf_rys SHARED
4-
rys_contract_jk.cu rys_jk_driver.cu unrolled_os.cu unrolled_rys.cu
4+
rys_contract_jk.cu rys_jk_driver.cu rys_roots_dat.cu
5+
unrolled_os.cu unrolled_rys.cu
56
nr_sr_estimator.c
67
rys_contract_j.cu cart2xyz.c unrolled_rys_j.cu
78
count_tasks.cu

0 commit comments

Comments
 (0)