PTX Backend #18

Open
WillTrojak wants to merge 21 commits into PyFR:master from WillTrojak:feature/ptx

Conversation

@WillTrojak

Member

This adds a PTX backend to GiMMiK. The key features are:

  • Mild optimisation of the existing CUDA algorithms.
  • Optional async loads for some sparse kernels.
  • Dense kernel generation for Hopper and above.

Optimisations have focused on FP64; FP32 is future work.

@FreddieWitherden

Contributor

I know this is an utter pain, but for FP32/FP64 can you confirm correctness for all relevant PyFR matrices at a suite of N values, in every instance where a kernel is expected to work (on A100/H100/B100)?
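
A minimal sketch of the kind of check being asked for, assuming the generated kernel can be wrapped in a callable `run_kernel(A, B)` that returns `C = A @ B`. That wrapper, the function name `check_correctness`, and the tolerances are all hypothetical and not part of GiMMiK; the reference result is a dense NumPy product.

```python
import numpy as np

def check_correctness(run_kernel, A, n_values, dtype=np.float64, seed=0):
    # run_kernel(A, B) is a hypothetical wrapper around a generated kernel
    rng = np.random.default_rng(seed)

    # Assumed tolerances; tighten or loosen to suit the matrices in question
    tol = 1e-5 if dtype == np.float32 else 1e-12

    A = np.asarray(A, dtype=dtype)
    failures = []
    for n in n_values:
        B = rng.random((A.shape[1], n)).astype(dtype)
        ref = A @ B
        if not np.allclose(run_kernel(A, B), ref, rtol=tol, atol=tol):
            failures.append(n)

    # An empty list means every tested N matched the reference
    return failures
```

Running this once per operator matrix, precision, and device would give a table of failing N values rather than a single pass/fail flag.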

Comment thread gimmik/kernels/ptx/base.mako Outdated
.param .u64 _c)
{
% endif
.reg .u32 n, id, tid_x, tid_y;

Contributor

Ensure we throw higher up if n is too big.

Contributor

Checking here

Member Author

We don't handle n being too large in any of the other backends.

Contributor

In the embedded case we do: https://github.com/PyFR/GiMMiK/blob/master/gimmik/kernels/cuda/cstream.mako#L20 (the argument case doesn't, but that is not currently used for CUDA).
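
A minimal sketch of failing early on the Python side, assuming "too big" means that `n` does not fit in the unsigned 32-bit register declared in the template (`.reg .u32 n`). The helper name `check_n_fits_u32` is hypothetical, and the real limit may be tighter once `n` is multiplied by strides or element sizes.

```python
U32_MAX = 2**32 - 1

def check_n_fits_u32(n):
    # Hypothetical helper: reject values a .u32 register cannot represent
    if not 0 < n <= U32_MAX:
        raise ValueError(f'n={n} is outside the range of a u32 index')

    return n
```

Calling a check like this before the template is rendered keeps the error in Python rather than surfacing as silent index wrap-around on the device.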

@FreddieWitherden

Contributor

JSON looks solid. See if we can factor out some of the common code so that other backends (CUDA) can also use it. That also makes the code easier to evaluate standalone. I'll start trying to chunk through the kernels, but it would be great if you could give a one-sentence sketch of the general approach of each.

Comment thread gimmik/ptx.py
}

# Map Supported CC -> Minimum PTX version
PTX_SM = {(8, 0): (7, 0), (9, 0): (8, 6), (10, 0): (8, 7), (10, 3): (8, 7),

Contributor

Is this okay when new GPUs are released?

Member Author

Yeah, when new GPUs are released, the behaviour will fall back to the default config. This won't give the best performance but it will work.
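
A minimal sketch of one possible fallback policy for an unrecognised compute capability, assuming the lookup drops back to the newest known entry that does not exceed the device. This is an illustration only and not necessarily the policy used in the PR; the function name `lookup_ptx_version` is hypothetical, and the table holds just the entries visible in the excerpt above.

```python
# Entries visible in the excerpt; the real table continues beyond these
PTX_SM = {(8, 0): (7, 0), (9, 0): (8, 6), (10, 0): (8, 7), (10, 3): (8, 7)}

def lookup_ptx_version(cc, table=PTX_SM):
    # Exact match for a known compute capability
    if cc in table:
        return table[cc]

    # Otherwise use the newest known capability that does not exceed cc
    older = [k for k in table if k <= cc]
    if not older:
        raise ValueError(f'Compute capability {cc} is not supported')

    return table[max(older)]
```

Tuples compare lexicographically, so `(8, 9)` resolves to the `(8, 0)` entry while anything newer than `(10, 3)` resolves to the last entry in the table.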

@FreddieWitherden

Contributor

Does it make sense to move config up a level so it is configs/ptx/ rather than it being under kernels?

## Main loop over B-chunks (double-buffered)
% for bb in range(len(bchunks)):
<%
buf_cur = bb % 2

Contributor

Check indentation here.

Comment thread gimmik/base.py
pass

def _get_config(self, key):
if key not in self._config_cache:

Contributor

EAFP
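
A minimal sketch of the EAFP ("easier to ask forgiveness than permission") form of a cached lookup, replacing the `if key not in ...` test with a `try`/`except KeyError`. The rest of `_get_config` is not shown in the excerpt, so the class name `ConfigCache` and the `_loader` callable are hypothetical stand-ins.

```python
class ConfigCache:
    def __init__(self, loader):
        # loader is a hypothetical callable that builds a config for a key
        self._loader = loader
        self._config_cache = {}

    def _get_config(self, key):
        # EAFP: attempt the lookup and only load on a miss
        try:
            return self._config_cache[key]
        except KeyError:
            cfg = self._config_cache[key] = self._loader(key)
            return cfg
```

The common path (a cache hit) then costs a single dictionary lookup instead of a membership test followed by a lookup.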

Comment thread gimmik/base.py
  # At single precision suffix all floating point constants by 'f'
- if dtype == 'float':
+ # (PTX doesn't use an 'f' suffix for FP literals)
+ if dtype == 'float' and self.platform != 'ptx':

Contributor

Have an attr like _needs_fp32_suffix = True|False to avoid the PTX check.
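
A minimal sketch of the suggested attribute, assuming a base class that formats constants and a PTX subclass that overrides the flag. The class names and the `_format_const` method are hypothetical; only `_needs_fp32_suffix` comes from the comment above.

```python
class BaseBackend:
    # Backends whose FP32 literals take an 'f' suffix leave this as True
    _needs_fp32_suffix = True

    def _format_const(self, value, dtype):
        s = repr(float(value))

        # The base class consults the flag rather than naming a platform
        if dtype == 'float' and self._needs_fp32_suffix:
            s += 'f'

        return s


class PTXBackend(BaseBackend):
    # PTX does not use an 'f' suffix for FP literals
    _needs_fp32_suffix = False
```

With this shape the base class never has to know which platforms exist; each backend declares its own literal convention.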

Comment thread gimmik/ptx.py
cfg = [k for k in cfgs if self._usable_config(k, dtype, cc, smem_info)]

for k in cfg:
if prepared := self._get_render_args(

Contributor

Probably cleaner not to use walrus here.
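
A minimal sketch of the same loop without the assignment expression, assuming the body that follows the `if` simply consumes `prepared`. The excerpt does not show that body, so the generator `iter_prepared` and the `prepare` callable are hypothetical stand-ins for the surrounding method and `self._get_render_args`.

```python
def iter_prepared(cfgs, prepare):
    for k in cfgs:
        # Plain assignment followed by a test, in place of `if prepared := ...`
        prepared = prepare(k)
        if prepared:
            yield k, prepared
```

The truthiness test is kept as in the original, so a config for which `prepare` returns `None` is skipped.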

Comment thread gimmik/ptx.py
def _sparse_args(self, tpl, params, block, dtype, dsize, args, meta):
blockx = block[0]
args |= {'has_zero_rows': bool(self.has_zero_rows),
'row_nz': [[(kx, self.A[j, kx]) for kx in range(self.k)

Contributor

Messy; NumPy should help here.
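
A minimal sketch of how NumPy could build the per-row list of `(column, value)` pairs, assuming `A` is a dense 2-D array. The function name `row_nonzeros` is hypothetical. Note that `tolist()` converts the values to Python floats, which may or may not suit how the template formats constants.

```python
import numpy as np

def row_nonzeros(A):
    out = []
    for row in np.asarray(A):
        # Column indices of the non-zero entries in this row
        kx = np.flatnonzero(row)

        # Pair each index with its value
        out.append(list(zip(kx.tolist(), row[kx].tolist())))

    return out
```

Rows that are entirely zero produce an empty list, matching the behaviour of the nested comprehension.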

Comment thread gimmik/ptx.py
args |= {'has_zero_rows': bool(self.has_zero_rows),
'row_nz': [[(kx, self.A[j, kx]) for kx in range(self.k)
if self.A[j, kx] != 0] for j in range(self.m)],
'preload_c': bool(params.get('preload_c', False)),

Contributor

Is preload_c not always in params? Try to avoid being overly defensive.

Comment thread gimmik/ptx.py
if tpl.startswith('dmma-asmem'):
args |= {
'a_copy_threads': 32 * warps,
'block_stealing': bool(params.get('block_stealing', False)),

Contributor

Same here, try to avoid being overly defensive.

Comment thread gimmik/ptx.py
tpl = kernel_cfg['template']
nn = params['nn']
warps = params['warps']
tile = kernel_cfg['tile']

Contributor

Can put some of these definitions onto the same line.

Comment thread gimmik/ptx.py
'b_smem_kgroup_stride': 4 * n_per_cta * args['dwidth_i'],
'b_smem_ntile_stride': setup['tile_n'] * args['dwidth_i'],
'blockx_total': 32 * warps * msplit,
} | offsets

Contributor

Maybe merge offsets in the return statement instead, so: return tpl, args | offsets, meta

Comment thread gimmik/ptx.py
ptx_shape = f'm{tile_m}n{tile_n}k{tile_k}'

m_groups, k_groups = tile_m // 8, tile_k // 4
a_regs = m_groups * k_groups

Contributor

No space around * in general.

Comment thread gimmik/ptx.py
return None

if (width == 2
and (self.aligne is None or self.aligne % 2

Contributor

Icky indentation.

Comment thread gimmik/ptx.py
@staticmethod
def _pred_emit(instr, *preds, pred_reg=None, indent=8 * ' '):
# Handle whether an instruction needs a predicate or not
actual = [p for p in preds if p is not None]

Contributor

Why would None get passed in?


2 participants