PTX Backend#18
Conversation
|
I know this is an utter pain but for FP32/FP64 can you confirm correctness for all relevant PyFR matrices at a suite of N values for all instances where a kernel is expected to work on A100/H100/B100)? |
| .param .u64 _c) | ||
| { | ||
| % endif | ||
| .reg .u32 n, id, tid_x, tid_y; |
There was a problem hiding this comment.
Ensure we throw higher up if n is too big.
There was a problem hiding this comment.
We don't handle n being too large in any of the other backends.
There was a problem hiding this comment.
https://github.com/PyFR/GiMMiK/blob/master/gimmik/kernels/cuda/cstream.mako#L20 in the embedded case we do (argument case doesn't but that is not currently used for CUDA).
|
JSON looks solid. See if we can factor out some of the common code so that other backends (CUDA) can also use it. Also just makes the code easier to evaluate standalone. I'll start trying to chunk through the kernels, but it would be great if you could give a once sentence sketch of their general approach. |
| } | ||
|
|
||
| # Map Supported CC -> Minimum PTX version | ||
| PTX_SM = {(8, 0): (7, 0), (9, 0): (8, 6), (10, 0): (8, 7), (10, 3): (8, 7), |
There was a problem hiding this comment.
Is this okay when new GPUs are released?
There was a problem hiding this comment.
Yeah, when new GPUs are released, the behaviour will fall back to the default config. This won't give the best performance but it will work.
|
Does it make sense to move config up a level so it is configs/ptx/ rather than it being under kernels? |
| ## Main loop over B-chunks (double-buffered) | ||
| % for bb in range(len(bchunks)): | ||
| <% | ||
| buf_cur = bb % 2 |
There was a problem hiding this comment.
Check indentation here.
| pass | ||
|
|
||
| def _get_config(self, key): | ||
| if key not in self._config_cache: |
| # At single precision suffix all floating point constants by 'f' | ||
| if dtype == 'float': | ||
| # (PTX doesn't use an 'f' suffix for FP literals) | ||
| if dtype == 'float' and self.platform != 'ptx': |
There was a problem hiding this comment.
Have an attr like _needs_fp32_suffix = True|False to avoid the PTX check.
| cfg = [k for k in cfgs if self._usable_config(k, dtype, cc, smem_info)] | ||
|
|
||
| for k in cfg: | ||
| if prepared := self._get_render_args( |
There was a problem hiding this comment.
Probably cleaner not to use walrus here.
| def _sparse_args(self, tpl, params, block, dtype, dsize, args, meta): | ||
| blockx = block[0] | ||
| args |= {'has_zero_rows': bool(self.has_zero_rows), | ||
| 'row_nz': [[(kx, self.A[j, kx]) for kx in range(self.k) |
There was a problem hiding this comment.
Messy; NumPy should help here.
| args |= {'has_zero_rows': bool(self.has_zero_rows), | ||
| 'row_nz': [[(kx, self.A[j, kx]) for kx in range(self.k) | ||
| if self.A[j, kx] != 0] for j in range(self.m)], | ||
| 'preload_c': bool(params.get('preload_c', False)), |
There was a problem hiding this comment.
Is preload_c not always in params? Try to avoid overly defensive.
| if tpl.startswith('dmma-asmem'): | ||
| args |= { | ||
| 'a_copy_threads': 32 * warps, | ||
| 'block_stealing': bool(params.get('block_stealing', False)), |
There was a problem hiding this comment.
Same here, try to avoid being overly defensive.
| tpl = kernel_cfg['template'] | ||
| nn = params['nn'] | ||
| warps = params['warps'] | ||
| tile = kernel_cfg['tile'] |
There was a problem hiding this comment.
Can put some of these definitions onto the same line.
| 'b_smem_kgroup_stride': 4 * n_per_cta * args['dwidth_i'], | ||
| 'b_smem_ntile_stride': setup['tile_n'] * args['dwidth_i'], | ||
| 'blockx_total': 32 * warps * msplit, | ||
| } | offsets |
There was a problem hiding this comment.
Maybe merge offsets in in the return statement, so return tpl, args | offsets, meta
| ptx_shape = f'm{tile_m}n{tile_n}k{tile_k}' | ||
|
|
||
| m_groups, k_groups = tile_m // 8, tile_k // 4 | ||
| a_regs = m_groups * k_groups |
There was a problem hiding this comment.
No space around * in general.
| return None | ||
|
|
||
| if (width == 2 | ||
| and (self.aligne is None or self.aligne % 2 |
| @staticmethod | ||
| def _pred_emit(instr, *preds, pred_reg=None, indent=8 * ' '): | ||
| # Handle whether an instruction needs a predicate or not | ||
| actual = [p for p in preds if p is not None] |
There was a problem hiding this comment.
Why would None get passed in?
This adds a PTX backend to GiMMiK. The key features are:
Optimisations have focused on FP64, FP32 is future work.