whiterose

linux unikernel
git clone https://git.ne02ptzero.me/git/whiterose

bpf_jit_comp.c (25591B)


      1 /*
      2  * BPF JIT compiler for ARM64
      3  *
      4  * Copyright (C) 2014-2016 Zi Shen Lim <zlim.lnx@gmail.com>
      5  *
      6  * This program is free software; you can redistribute it and/or modify
      7  * it under the terms of the GNU General Public License version 2 as
      8  * published by the Free Software Foundation.
      9  *
     10  * This program is distributed in the hope that it will be useful,
     11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
     12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     13  * GNU General Public License for more details.
     14  *
     15  * You should have received a copy of the GNU General Public License
     16  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
     17  */
     18 
     19 #define pr_fmt(fmt) "bpf_jit: " fmt
     20 
     21 #include <linux/bpf.h>
     22 #include <linux/filter.h>
     23 #include <linux/printk.h>
     24 #include <linux/slab.h>
     25 
     26 #include <asm/byteorder.h>
     27 #include <asm/cacheflush.h>
     28 #include <asm/debug-monitors.h>
     29 #include <asm/set_memory.h>
     30 
     31 #include "bpf_jit.h"
     32 
     33 #define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
     34 #define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
     35 #define TCALL_CNT (MAX_BPF_JIT_REG + 2)
     36 #define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
     37 
     38 /* Map BPF registers to A64 registers */
     39 static const int bpf2a64[] = {
     40 	/* return value from in-kernel function, and exit value from eBPF */
     41 	[BPF_REG_0] = A64_R(7),
     42 	/* arguments from eBPF program to in-kernel function */
     43 	[BPF_REG_1] = A64_R(0),
     44 	[BPF_REG_2] = A64_R(1),
     45 	[BPF_REG_3] = A64_R(2),
     46 	[BPF_REG_4] = A64_R(3),
     47 	[BPF_REG_5] = A64_R(4),
     48 	/* callee-saved registers that the in-kernel function will preserve */
     49 	[BPF_REG_6] = A64_R(19),
     50 	[BPF_REG_7] = A64_R(20),
     51 	[BPF_REG_8] = A64_R(21),
     52 	[BPF_REG_9] = A64_R(22),
     53 	/* read-only frame pointer to access stack */
     54 	[BPF_REG_FP] = A64_R(25),
     55 	/* temporary registers for internal BPF JIT */
     56 	[TMP_REG_1] = A64_R(10),
     57 	[TMP_REG_2] = A64_R(11),
     58 	[TMP_REG_3] = A64_R(12),
     59 	/* tail_call_cnt */
     60 	[TCALL_CNT] = A64_R(26),
     61 	/* temporary register for blinding constants */
     62 	[BPF_REG_AX] = A64_R(9),
     63 };
     64 
     65 struct jit_ctx {
     66 	const struct bpf_prog *prog;
     67 	int idx;
     68 	int epilogue_offset;
     69 	int *offset;
     70 	__le32 *image;
     71 	u32 stack_size;
     72 };
     73 
     74 static inline void emit(const u32 insn, struct jit_ctx *ctx)
     75 {
     76 	if (ctx->image != NULL)
     77 		ctx->image[ctx->idx] = cpu_to_le32(insn);
     78 
     79 	ctx->idx++;
     80 }
     81 
     82 static inline void emit_a64_mov_i(const int is64, const int reg,
     83 				  const s32 val, struct jit_ctx *ctx)
     84 {
     85 	u16 hi = val >> 16;
     86 	u16 lo = val & 0xffff;
     87 
     88 	if (hi & 0x8000) {
     89 		if (hi == 0xffff) {
     90 			emit(A64_MOVN(is64, reg, (u16)~lo, 0), ctx);
     91 		} else {
     92 			emit(A64_MOVN(is64, reg, (u16)~hi, 16), ctx);
     93 			if (lo != 0xffff)
     94 				emit(A64_MOVK(is64, reg, lo, 0), ctx);
     95 		}
     96 	} else {
     97 		emit(A64_MOVZ(is64, reg, lo, 0), ctx);
     98 		if (hi)
     99 			emit(A64_MOVK(is64, reg, hi, 16), ctx);
    100 	}
    101 }
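
To see the decision tree above in action, here is a small standalone sketch (ordinary userspace C, not part of this file; the mnemonics it prints are illustrative, the real encodings come from bpf_jit.h). A 32-bit immediate whose upper half looks negative is loaded with MOVN plus an optional MOVK, otherwise with MOVZ plus an optional MOVK.

    #include <stdint.h>
    #include <stdio.h>

    /* Mirror of emit_a64_mov_i()'s decision tree, printing the plan. */
    static void mov_i_plan(int32_t val)
    {
        uint16_t hi = (uint32_t)val >> 16;
        uint16_t lo = (uint32_t)val & 0xffff;

        if (hi & 0x8000) {
            if (hi == 0xffff) {
                printf("MOVN reg, #0x%04x\n", (uint16_t)~lo);
            } else {
                printf("MOVN reg, #0x%04x, lsl #16\n", (uint16_t)~hi);
                if (lo != 0xffff)
                    printf("MOVK reg, #0x%04x\n", lo);
            }
        } else {
            printf("MOVZ reg, #0x%04x\n", lo);
            if (hi)
                printf("MOVK reg, #0x%04x, lsl #16\n", hi);
        }
    }

    int main(void)
    {
        mov_i_plan(0x12345678); /* MOVZ #0x5678 + MOVK #0x1234, lsl #16 */
        mov_i_plan(-5);         /* single MOVN #0x0004                  */
        return 0;
    }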
    102 
    103 static int i64_i16_blocks(const u64 val, bool inverse)
    104 {
    105 	return (((val >>  0) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
    106 	       (((val >> 16) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
    107 	       (((val >> 32) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
    108 	       (((val >> 48) & 0xffff) != (inverse ? 0xffff : 0x0000));
    109 }
    110 
    111 static inline void emit_a64_mov_i64(const int reg, const u64 val,
    112 				    struct jit_ctx *ctx)
    113 {
    114 	u64 nrm_tmp = val, rev_tmp = ~val;
    115 	bool inverse;
    116 	int shift;
    117 
    118 	if (!(nrm_tmp >> 32))
    119 		return emit_a64_mov_i(0, reg, (u32)val, ctx);
    120 
    121 	inverse = i64_i16_blocks(nrm_tmp, true) < i64_i16_blocks(nrm_tmp, false);
    122 	shift = max(round_down((inverse ? (fls64(rev_tmp) - 1) :
    123 					  (fls64(nrm_tmp) - 1)), 16), 0);
    124 	if (inverse)
    125 		emit(A64_MOVN(1, reg, (rev_tmp >> shift) & 0xffff, shift), ctx);
    126 	else
    127 		emit(A64_MOVZ(1, reg, (nrm_tmp >> shift) & 0xffff, shift), ctx);
    128 	shift -= 16;
    129 	while (shift >= 0) {
    130 		if (((nrm_tmp >> shift) & 0xffff) != (inverse ? 0xffff : 0x0000))
    131 			emit(A64_MOVK(1, reg, (nrm_tmp >> shift) & 0xffff, shift), ctx);
    132 		shift -= 16;
    133 	}
    134 }
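
For 64-bit immediates, the helper above first counts how many 16-bit chunks differ from the fill pattern of each form (0x0000 for MOVZ, 0xffff for MOVN) and starts with whichever form leaves fewer chunks to patch. A hedged standalone sketch of that counting (not part of this file):

    #include <stdint.h>
    #include <stdio.h>

    /* Same counting as i64_i16_blocks(): each 16-bit chunk that differs
     * from the fill pattern costs one instruction (the first is the
     * MOVZ/MOVN itself, the rest are MOVKs).                           */
    static int chunks(uint64_t val, int inverse)
    {
        int n = 0, s;

        for (s = 0; s < 64; s += 16)
            n += ((val >> s) & 0xffff) != (inverse ? 0xffffu : 0x0000u);
        return n;
    }

    int main(void)
    {
        uint64_t v = 0xffffffffffff1234ull;

        /* MOVZ form would need 4 instructions, MOVN form only 1. */
        printf("movz-form chunks: %d\n", chunks(v, 0));
        printf("movn-form chunks: %d\n", chunks(v, 1));
        return 0;
    }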
    135 
    136 /*
    137  * Kernel addresses in the vmalloc space use at most 48 bits, and the
    138  * remaining upper bits are guaranteed to be all ones. So we can compose the address
    139  * with a fixed length movn/movk/movk sequence.
    140  */
    141 static inline void emit_addr_mov_i64(const int reg, const u64 val,
    142 				     struct jit_ctx *ctx)
    143 {
    144 	u64 tmp = val;
    145 	int shift = 0;
    146 
    147 	emit(A64_MOVN(1, reg, ~tmp & 0xffff, shift), ctx);
    148 	while (shift < 32) {
    149 		tmp >>= 16;
    150 		shift += 16;
    151 		emit(A64_MOVK(1, reg, tmp & 0xffff, shift), ctx);
    152 	}
    153 }
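
A fixed MOVN plus two MOVKs also means the sizing pass and the final pass emit the same number of instructions even if the exact target address only becomes known later. A hedged standalone sketch (hypothetical address value, ordinary userspace C) of why the composition is correct when the top 16 bits are all ones:

    #include <assert.h>
    #include <stdint.h>
    #include <stdio.h>

    /* Compose a value the way emit_addr_mov_i64() does: MOVN writes the
     * complement of (~addr & 0xffff), which yields the low 16 bits of
     * addr with every higher bit set; two MOVKs then patch bits 16-47. */
    static uint64_t compose_addr(uint64_t addr)
    {
        uint64_t reg = ~(~addr & 0xffff);                              /* MOVN          */

        reg = (reg & ~0xffff0000ull)     | (addr & 0xffff0000ull);     /* MOVK, lsl #16 */
        reg = (reg & ~0xffff00000000ull) | (addr & 0xffff00000000ull); /* MOVK, lsl #32 */
        return reg;
    }

    int main(void)
    {
        uint64_t addr = 0xffff800011223344ull; /* hypothetical vmalloc-style address */

        assert(compose_addr(addr) == addr);
        printf("composed 0x%llx\n", (unsigned long long)compose_addr(addr));
        return 0;
    }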
    154 
    155 static inline int bpf2a64_offset(int bpf_to, int bpf_from,
    156 				 const struct jit_ctx *ctx)
    157 {
    158 	int to = ctx->offset[bpf_to];
    159 	/* -1 to account for the Branch instruction */
    160 	int from = ctx->offset[bpf_from] - 1;
    161 
    162 	return to - from;
    163 }
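
Here ctx->offset[i] (filled in by build_body() further down) records the A64 instruction index just past the code emitted for BPF instruction i, so offset[bpf_to] is where execution resumes after a "pc += off" jump, and the branch itself sits one slot before offset[bpf_from]. A toy hedged illustration with made-up offsets (not part of this file):

    #include <stdio.h>

    /* Same arithmetic as bpf2a64_offset(); the result is a branch
     * immediate counted in A64 instructions, relative to the branch. */
    static int a64_branch_imm(const int *offset, int bpf_to, int bpf_from)
    {
        int to = offset[bpf_to];
        int from = offset[bpf_from] - 1;   /* the branch instruction itself */

        return to - from;
    }

    int main(void)
    {
        /* Hypothetical layout: BPF insn 0 emitted 2 A64 insns, insn 1
         * emitted 1 (the branch), insn 2 emitted 3, insn 3 emitted 1. */
        int offset[] = { 2, 3, 6, 7 };

        /* BPF insn 1 is "goto +1": skip insn 2, resume at insn 3. */
        printf("branch immediate: %d\n", a64_branch_imm(offset, 1 + 1, 1));
        return 0;
    }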
    164 
    165 static void jit_fill_hole(void *area, unsigned int size)
    166 {
    167 	__le32 *ptr;
    168 	/* We are guaranteed to have aligned memory. */
    169 	for (ptr = area; size >= sizeof(u32); size -= sizeof(u32))
    170 		*ptr++ = cpu_to_le32(AARCH64_BREAK_FAULT);
    171 }
    172 
    173 static inline int epilogue_offset(const struct jit_ctx *ctx)
    174 {
    175 	int to = ctx->epilogue_offset;
    176 	int from = ctx->idx;
    177 
    178 	return to - from;
    179 }
    180 
    181 /* Stack must be a multiple of 16B */
    182 #define STACK_ALIGN(sz) (((sz) + 15) & ~15)
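
A quick hedged example of the rounding (standalone, not part of the file): the verifier-reported stack_depth is rounded up to the 16-byte alignment the AAPCS requires before build_prologue() below subtracts it from SP.

    #include <stdio.h>

    #define STACK_ALIGN(sz) (((sz) + 15) & ~15)   /* same macro as above */

    int main(void)
    {
        printf("stack_depth 20 -> stack_size %d\n", STACK_ALIGN(20));  /* 32 */
        printf("stack_depth 64 -> stack_size %d\n", STACK_ALIGN(64));  /* 64 */
        return 0;
    }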
    183 
    184 /* Tail call offset to jump into */
    185 #define PROLOGUE_OFFSET 7
    186 
    187 static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
    188 {
    189 	const struct bpf_prog *prog = ctx->prog;
    190 	const u8 r6 = bpf2a64[BPF_REG_6];
    191 	const u8 r7 = bpf2a64[BPF_REG_7];
    192 	const u8 r8 = bpf2a64[BPF_REG_8];
    193 	const u8 r9 = bpf2a64[BPF_REG_9];
    194 	const u8 fp = bpf2a64[BPF_REG_FP];
    195 	const u8 tcc = bpf2a64[TCALL_CNT];
    196 	const int idx0 = ctx->idx;
    197 	int cur_offset;
    198 
    199 	/*
    200 	 * BPF prog stack layout
    201 	 *
    202 	 *                         high
    203 	 * original A64_SP =>   0:+-----+ BPF prologue
    204 	 *                        |FP/LR|
    205 	 * current A64_FP =>  -16:+-----+
    206 	 *                        | ... | callee saved registers
    207 	 * BPF fp register => -64:+-----+ <= (BPF_FP)
    208 	 *                        |     |
    209 	 *                        | ... | BPF prog stack
    210 	 *                        |     |
    211 	 *                        +-----+ <= (BPF_FP - prog->aux->stack_depth)
    212 	 *                        |RSVD | padding
    213 	 * current A64_SP =>      +-----+ <= (BPF_FP - ctx->stack_size)
    214 	 *                        |     |
    215 	 *                        | ... | Function call stack
    216 	 *                        |     |
    217 	 *                        +-----+
    218 	 *                          low
    219 	 *
    220 	 */
    221 
    222 	/* Save FP and LR registers to stay aligned with the ARM64 AAPCS */
    223 	emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
    224 	emit(A64_MOV(1, A64_FP, A64_SP), ctx);
    225 
    226 	/* Save callee-saved registers */
    227 	emit(A64_PUSH(r6, r7, A64_SP), ctx);
    228 	emit(A64_PUSH(r8, r9, A64_SP), ctx);
    229 	emit(A64_PUSH(fp, tcc, A64_SP), ctx);
    230 
    231 	/* Set up BPF prog stack base register */
    232 	emit(A64_MOV(1, fp, A64_SP), ctx);
    233 
    234 	if (!ebpf_from_cbpf) {
    235 		/* Initialize tail_call_cnt */
    236 		emit(A64_MOVZ(1, tcc, 0, 0), ctx);
    237 
    238 		cur_offset = ctx->idx - idx0;
    239 		if (cur_offset != PROLOGUE_OFFSET) {
    240 			pr_err_once("PROLOGUE_OFFSET = %d, expected %d!\n",
    241 				    cur_offset, PROLOGUE_OFFSET);
    242 			return -1;
    243 		}
    244 	}
    245 
    246 	ctx->stack_size = STACK_ALIGN(prog->aux->stack_depth);
    247 
    248 	/* Set up function call stack */
    249 	emit(A64_SUB_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
    250 	return 0;
    251 }
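
The PROLOGUE_OFFSET check above pins the number of instructions emitted before this point, because emit_bpf_tail_call() later branches to bpf_func plus sizeof(u32) * PROLOGUE_OFFSET, landing just after the tail_call_cnt initialization so the register saves are not redone and the counter keeps accumulating across tail calls. A one-line hedged illustration of the byte offset (standalone, not part of the file):

    #include <stdio.h>

    #define PROLOGUE_OFFSET 7   /* same constant as above */

    int main(void)
    {
        /* The tail-call entry point is this many bytes past bpf_func. */
        printf("%zu bytes\n", sizeof(unsigned int) * PROLOGUE_OFFSET);   /* 28 */
        return 0;
    }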
    252 
    253 static int out_offset = -1; /* initialized on the first pass of build_body() */
    254 static int emit_bpf_tail_call(struct jit_ctx *ctx)
    255 {
    256 	/* bpf_tail_call(void *prog_ctx, struct bpf_array *array, u64 index) */
    257 	const u8 r2 = bpf2a64[BPF_REG_2];
    258 	const u8 r3 = bpf2a64[BPF_REG_3];
    259 
    260 	const u8 tmp = bpf2a64[TMP_REG_1];
    261 	const u8 prg = bpf2a64[TMP_REG_2];
    262 	const u8 tcc = bpf2a64[TCALL_CNT];
    263 	const int idx0 = ctx->idx;
    264 #define cur_offset (ctx->idx - idx0)
    265 #define jmp_offset (out_offset - (cur_offset))
    266 	size_t off;
    267 
    268 	/* if (index >= array->map.max_entries)
    269 	 *     goto out;
    270 	 */
    271 	off = offsetof(struct bpf_array, map.max_entries);
    272 	emit_a64_mov_i64(tmp, off, ctx);
    273 	emit(A64_LDR32(tmp, r2, tmp), ctx);
    274 	emit(A64_MOV(0, r3, r3), ctx);
    275 	emit(A64_CMP(0, r3, tmp), ctx);
    276 	emit(A64_B_(A64_COND_CS, jmp_offset), ctx);
    277 
    278 	/* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
    279 	 *     goto out;
    280 	 * tail_call_cnt++;
    281 	 */
    282 	emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
    283 	emit(A64_CMP(1, tcc, tmp), ctx);
    284 	emit(A64_B_(A64_COND_HI, jmp_offset), ctx);
    285 	emit(A64_ADD_I(1, tcc, tcc, 1), ctx);
    286 
    287 	/* prog = array->ptrs[index];
    288 	 * if (prog == NULL)
    289 	 *     goto out;
    290 	 */
    291 	off = offsetof(struct bpf_array, ptrs);
    292 	emit_a64_mov_i64(tmp, off, ctx);
    293 	emit(A64_ADD(1, tmp, r2, tmp), ctx);
    294 	emit(A64_LSL(1, prg, r3, 3), ctx);
    295 	emit(A64_LDR64(prg, tmp, prg), ctx);
    296 	emit(A64_CBZ(1, prg, jmp_offset), ctx);
    297 
    298 	/* goto *(prog->bpf_func + prologue_offset); */
    299 	off = offsetof(struct bpf_prog, bpf_func);
    300 	emit_a64_mov_i64(tmp, off, ctx);
    301 	emit(A64_LDR64(tmp, prg, tmp), ctx);
    302 	emit(A64_ADD_I(1, tmp, tmp, sizeof(u32) * PROLOGUE_OFFSET), ctx);
    303 	emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
    304 	emit(A64_BR(tmp), ctx);
    305 
    306 	/* out: */
    307 	if (out_offset == -1)
    308 		out_offset = cur_offset;
    309 	if (cur_offset != out_offset) {
    310 		pr_err_once("tail_call out_offset = %d, expected %d!\n",
    311 			    cur_offset, out_offset);
    312 		return -1;
    313 	}
    314 	return 0;
    315 #undef cur_offset
    316 #undef jmp_offset
    317 }
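
At the C level, the instructions emitted above implement roughly the following logic (a hedged sketch with toy struct definitions, not part of this file; the limit value is illustrative, and in the real code the last step is a direct branch past the target's PROLOGUE_OFFSET instructions rather than an ordinary call, so the tail_call_cnt register survives):

    #include <stdint.h>
    #include <stdio.h>

    #define MAX_TAIL_CALL_CNT_SKETCH 32   /* illustrative value for this sketch */

    struct toy_prog  { void (*bpf_func)(void *ctx); };
    struct toy_array { uint32_t max_entries; struct toy_prog *ptrs[4]; };

    static void hello(void *ctx) { (void)ctx; puts("tail-called program"); }

    /* C-level equivalent of the emitted tail-call sequence above. */
    static void tail_call_sketch(void *ctx, struct toy_array *array,
                                 uint32_t index, uint32_t *tail_call_cnt)
    {
        struct toy_prog *prog;

        if (index >= array->max_entries)                /* CMP + B.CS -> out */
            return;
        if (*tail_call_cnt > MAX_TAIL_CALL_CNT_SKETCH)  /* CMP + B.HI -> out */
            return;
        (*tail_call_cnt)++;

        prog = array->ptrs[index];                      /* LDR from ptrs + (index << 3) */
        if (!prog)                                      /* CBZ -> out */
            return;

        /* The JIT actually unwinds the stack and branches to
         * bpf_func + sizeof(u32) * PROLOGUE_OFFSET, skipping the
         * target's prologue; a plain call stands in for that here. */
        prog->bpf_func(ctx);
    }

    int main(void)
    {
        struct toy_prog p = { .bpf_func = hello };
        struct toy_array arr = { .max_entries = 1, .ptrs = { &p } };
        uint32_t cnt = 0;

        tail_call_sketch(NULL, &arr, 0, &cnt);   /* runs hello()        */
        tail_call_sketch(NULL, &arr, 3, &cnt);   /* out: index too big  */
        return 0;
    }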
    318 
    319 static void build_epilogue(struct jit_ctx *ctx)
    320 {
    321 	const u8 r0 = bpf2a64[BPF_REG_0];
    322 	const u8 r6 = bpf2a64[BPF_REG_6];
    323 	const u8 r7 = bpf2a64[BPF_REG_7];
    324 	const u8 r8 = bpf2a64[BPF_REG_8];
    325 	const u8 r9 = bpf2a64[BPF_REG_9];
    326 	const u8 fp = bpf2a64[BPF_REG_FP];
    327 
    328 	/* We're done with BPF stack */
    329 	emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
    330 
    331 	/* Restore fp (x25) and x26 */
    332 	emit(A64_POP(fp, A64_R(26), A64_SP), ctx);
    333 
    334 	/* Restore callee-saved registers */
    335 	emit(A64_POP(r8, r9, A64_SP), ctx);
    336 	emit(A64_POP(r6, r7, A64_SP), ctx);
    337 
    338 	/* Restore FP/LR registers */
    339 	emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
    340 
    341 	/* Set return value */
    342 	emit(A64_MOV(1, A64_R(0), r0), ctx);
    343 
    344 	emit(A64_RET(A64_LR), ctx);
    345 }
    346 
    347 /* JITs an eBPF instruction.
    348  * Returns:
    349  * 0  - successfully JITed an 8-byte eBPF instruction.
    350  * >0 - successfully JITed a 16-byte eBPF instruction.
    351  * <0 - failed to JIT.
    352  */
    353 static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
    354 		      bool extra_pass)
    355 {
    356 	const u8 code = insn->code;
    357 	const u8 dst = bpf2a64[insn->dst_reg];
    358 	const u8 src = bpf2a64[insn->src_reg];
    359 	const u8 tmp = bpf2a64[TMP_REG_1];
    360 	const u8 tmp2 = bpf2a64[TMP_REG_2];
    361 	const u8 tmp3 = bpf2a64[TMP_REG_3];
    362 	const s16 off = insn->off;
    363 	const s32 imm = insn->imm;
    364 	const int i = insn - ctx->prog->insnsi;
    365 	const bool is64 = BPF_CLASS(code) == BPF_ALU64 ||
    366 			  BPF_CLASS(code) == BPF_JMP;
    367 	const bool isdw = BPF_SIZE(code) == BPF_DW;
    368 	u8 jmp_cond;
    369 	s32 jmp_offset;
    370 
    371 #define check_imm(bits, imm) do {				\
    372 	if ((((imm) > 0) && ((imm) >> (bits))) ||		\
    373 	    (((imm) < 0) && (~(imm) >> (bits)))) {		\
    374 		pr_info("[%2d] imm=%d(0x%x) out of range\n",	\
    375 			i, imm, imm);				\
    376 		return -EINVAL;					\
    377 	}							\
    378 } while (0)
    379 #define check_imm19(imm) check_imm(19, imm)
    380 #define check_imm26(imm) check_imm(26, imm)
    381 
    382 	switch (code) {
    383 	/* dst = src */
    384 	case BPF_ALU | BPF_MOV | BPF_X:
    385 	case BPF_ALU64 | BPF_MOV | BPF_X:
    386 		emit(A64_MOV(is64, dst, src), ctx);
    387 		break;
    388 	/* dst = dst OP src */
    389 	case BPF_ALU | BPF_ADD | BPF_X:
    390 	case BPF_ALU64 | BPF_ADD | BPF_X:
    391 		emit(A64_ADD(is64, dst, dst, src), ctx);
    392 		break;
    393 	case BPF_ALU | BPF_SUB | BPF_X:
    394 	case BPF_ALU64 | BPF_SUB | BPF_X:
    395 		emit(A64_SUB(is64, dst, dst, src), ctx);
    396 		break;
    397 	case BPF_ALU | BPF_AND | BPF_X:
    398 	case BPF_ALU64 | BPF_AND | BPF_X:
    399 		emit(A64_AND(is64, dst, dst, src), ctx);
    400 		break;
    401 	case BPF_ALU | BPF_OR | BPF_X:
    402 	case BPF_ALU64 | BPF_OR | BPF_X:
    403 		emit(A64_ORR(is64, dst, dst, src), ctx);
    404 		break;
    405 	case BPF_ALU | BPF_XOR | BPF_X:
    406 	case BPF_ALU64 | BPF_XOR | BPF_X:
    407 		emit(A64_EOR(is64, dst, dst, src), ctx);
    408 		break;
    409 	case BPF_ALU | BPF_MUL | BPF_X:
    410 	case BPF_ALU64 | BPF_MUL | BPF_X:
    411 		emit(A64_MUL(is64, dst, dst, src), ctx);
    412 		break;
    413 	case BPF_ALU | BPF_DIV | BPF_X:
    414 	case BPF_ALU64 | BPF_DIV | BPF_X:
    415 	case BPF_ALU | BPF_MOD | BPF_X:
    416 	case BPF_ALU64 | BPF_MOD | BPF_X:
    417 		switch (BPF_OP(code)) {
    418 		case BPF_DIV:
    419 			emit(A64_UDIV(is64, dst, dst, src), ctx);
    420 			break;
    421 		case BPF_MOD:
    422 			emit(A64_UDIV(is64, tmp, dst, src), ctx);
    423 			emit(A64_MUL(is64, tmp, tmp, src), ctx);
    424 			emit(A64_SUB(is64, dst, dst, tmp), ctx);
    425 			break;
    426 		}
    427 		break;
    428 	case BPF_ALU | BPF_LSH | BPF_X:
    429 	case BPF_ALU64 | BPF_LSH | BPF_X:
    430 		emit(A64_LSLV(is64, dst, dst, src), ctx);
    431 		break;
    432 	case BPF_ALU | BPF_RSH | BPF_X:
    433 	case BPF_ALU64 | BPF_RSH | BPF_X:
    434 		emit(A64_LSRV(is64, dst, dst, src), ctx);
    435 		break;
    436 	case BPF_ALU | BPF_ARSH | BPF_X:
    437 	case BPF_ALU64 | BPF_ARSH | BPF_X:
    438 		emit(A64_ASRV(is64, dst, dst, src), ctx);
    439 		break;
    440 	/* dst = -dst */
    441 	case BPF_ALU | BPF_NEG:
    442 	case BPF_ALU64 | BPF_NEG:
    443 		emit(A64_NEG(is64, dst, dst), ctx);
    444 		break;
    445 	/* dst = BSWAP##imm(dst) */
    446 	case BPF_ALU | BPF_END | BPF_FROM_LE:
    447 	case BPF_ALU | BPF_END | BPF_FROM_BE:
    448 #ifdef CONFIG_CPU_BIG_ENDIAN
    449 		if (BPF_SRC(code) == BPF_FROM_BE)
    450 			goto emit_bswap_uxt;
    451 #else /* !CONFIG_CPU_BIG_ENDIAN */
    452 		if (BPF_SRC(code) == BPF_FROM_LE)
    453 			goto emit_bswap_uxt;
    454 #endif
    455 		switch (imm) {
    456 		case 16:
    457 			emit(A64_REV16(is64, dst, dst), ctx);
    458 			/* zero-extend 16 bits into 64 bits */
    459 			emit(A64_UXTH(is64, dst, dst), ctx);
    460 			break;
    461 		case 32:
    462 			emit(A64_REV32(is64, dst, dst), ctx);
    463 			/* upper 32 bits already cleared */
    464 			break;
    465 		case 64:
    466 			emit(A64_REV64(dst, dst), ctx);
    467 			break;
    468 		}
    469 		break;
    470 emit_bswap_uxt:
    471 		switch (imm) {
    472 		case 16:
    473 			/* zero-extend 16 bits into 64 bits */
    474 			emit(A64_UXTH(is64, dst, dst), ctx);
    475 			break;
    476 		case 32:
    477 			/* zero-extend 32 bits into 64 bits */
    478 			emit(A64_UXTW(is64, dst, dst), ctx);
    479 			break;
    480 		case 64:
    481 			/* nop */
    482 			break;
    483 		}
    484 		break;
    485 	/* dst = imm */
    486 	case BPF_ALU | BPF_MOV | BPF_K:
    487 	case BPF_ALU64 | BPF_MOV | BPF_K:
    488 		emit_a64_mov_i(is64, dst, imm, ctx);
    489 		break;
    490 	/* dst = dst OP imm */
    491 	case BPF_ALU | BPF_ADD | BPF_K:
    492 	case BPF_ALU64 | BPF_ADD | BPF_K:
    493 		emit_a64_mov_i(is64, tmp, imm, ctx);
    494 		emit(A64_ADD(is64, dst, dst, tmp), ctx);
    495 		break;
    496 	case BPF_ALU | BPF_SUB | BPF_K:
    497 	case BPF_ALU64 | BPF_SUB | BPF_K:
    498 		emit_a64_mov_i(is64, tmp, imm, ctx);
    499 		emit(A64_SUB(is64, dst, dst, tmp), ctx);
    500 		break;
    501 	case BPF_ALU | BPF_AND | BPF_K:
    502 	case BPF_ALU64 | BPF_AND | BPF_K:
    503 		emit_a64_mov_i(is64, tmp, imm, ctx);
    504 		emit(A64_AND(is64, dst, dst, tmp), ctx);
    505 		break;
    506 	case BPF_ALU | BPF_OR | BPF_K:
    507 	case BPF_ALU64 | BPF_OR | BPF_K:
    508 		emit_a64_mov_i(is64, tmp, imm, ctx);
    509 		emit(A64_ORR(is64, dst, dst, tmp), ctx);
    510 		break;
    511 	case BPF_ALU | BPF_XOR | BPF_K:
    512 	case BPF_ALU64 | BPF_XOR | BPF_K:
    513 		emit_a64_mov_i(is64, tmp, imm, ctx);
    514 		emit(A64_EOR(is64, dst, dst, tmp), ctx);
    515 		break;
    516 	case BPF_ALU | BPF_MUL | BPF_K:
    517 	case BPF_ALU64 | BPF_MUL | BPF_K:
    518 		emit_a64_mov_i(is64, tmp, imm, ctx);
    519 		emit(A64_MUL(is64, dst, dst, tmp), ctx);
    520 		break;
    521 	case BPF_ALU | BPF_DIV | BPF_K:
    522 	case BPF_ALU64 | BPF_DIV | BPF_K:
    523 		emit_a64_mov_i(is64, tmp, imm, ctx);
    524 		emit(A64_UDIV(is64, dst, dst, tmp), ctx);
    525 		break;
    526 	case BPF_ALU | BPF_MOD | BPF_K:
    527 	case BPF_ALU64 | BPF_MOD | BPF_K:
    528 		emit_a64_mov_i(is64, tmp2, imm, ctx);
    529 		emit(A64_UDIV(is64, tmp, dst, tmp2), ctx);
    530 		emit(A64_MUL(is64, tmp, tmp, tmp2), ctx);
    531 		emit(A64_SUB(is64, dst, dst, tmp), ctx);
    532 		break;
    533 	case BPF_ALU | BPF_LSH | BPF_K:
    534 	case BPF_ALU64 | BPF_LSH | BPF_K:
    535 		emit(A64_LSL(is64, dst, dst, imm), ctx);
    536 		break;
    537 	case BPF_ALU | BPF_RSH | BPF_K:
    538 	case BPF_ALU64 | BPF_RSH | BPF_K:
    539 		emit(A64_LSR(is64, dst, dst, imm), ctx);
    540 		break;
    541 	case BPF_ALU | BPF_ARSH | BPF_K:
    542 	case BPF_ALU64 | BPF_ARSH | BPF_K:
    543 		emit(A64_ASR(is64, dst, dst, imm), ctx);
    544 		break;
    545 
    546 	/* JUMP off */
    547 	case BPF_JMP | BPF_JA:
    548 		jmp_offset = bpf2a64_offset(i + off, i, ctx);
    549 		check_imm26(jmp_offset);
    550 		emit(A64_B(jmp_offset), ctx);
    551 		break;
    552 	/* IF (dst COND src) JUMP off */
    553 	case BPF_JMP | BPF_JEQ | BPF_X:
    554 	case BPF_JMP | BPF_JGT | BPF_X:
    555 	case BPF_JMP | BPF_JLT | BPF_X:
    556 	case BPF_JMP | BPF_JGE | BPF_X:
    557 	case BPF_JMP | BPF_JLE | BPF_X:
    558 	case BPF_JMP | BPF_JNE | BPF_X:
    559 	case BPF_JMP | BPF_JSGT | BPF_X:
    560 	case BPF_JMP | BPF_JSLT | BPF_X:
    561 	case BPF_JMP | BPF_JSGE | BPF_X:
    562 	case BPF_JMP | BPF_JSLE | BPF_X:
    563 	case BPF_JMP32 | BPF_JEQ | BPF_X:
    564 	case BPF_JMP32 | BPF_JGT | BPF_X:
    565 	case BPF_JMP32 | BPF_JLT | BPF_X:
    566 	case BPF_JMP32 | BPF_JGE | BPF_X:
    567 	case BPF_JMP32 | BPF_JLE | BPF_X:
    568 	case BPF_JMP32 | BPF_JNE | BPF_X:
    569 	case BPF_JMP32 | BPF_JSGT | BPF_X:
    570 	case BPF_JMP32 | BPF_JSLT | BPF_X:
    571 	case BPF_JMP32 | BPF_JSGE | BPF_X:
    572 	case BPF_JMP32 | BPF_JSLE | BPF_X:
    573 		emit(A64_CMP(is64, dst, src), ctx);
    574 emit_cond_jmp:
    575 		jmp_offset = bpf2a64_offset(i + off, i, ctx);
    576 		check_imm19(jmp_offset);
    577 		switch (BPF_OP(code)) {
    578 		case BPF_JEQ:
    579 			jmp_cond = A64_COND_EQ;
    580 			break;
    581 		case BPF_JGT:
    582 			jmp_cond = A64_COND_HI;
    583 			break;
    584 		case BPF_JLT:
    585 			jmp_cond = A64_COND_CC;
    586 			break;
    587 		case BPF_JGE:
    588 			jmp_cond = A64_COND_CS;
    589 			break;
    590 		case BPF_JLE:
    591 			jmp_cond = A64_COND_LS;
    592 			break;
    593 		case BPF_JSET:
    594 		case BPF_JNE:
    595 			jmp_cond = A64_COND_NE;
    596 			break;
    597 		case BPF_JSGT:
    598 			jmp_cond = A64_COND_GT;
    599 			break;
    600 		case BPF_JSLT:
    601 			jmp_cond = A64_COND_LT;
    602 			break;
    603 		case BPF_JSGE:
    604 			jmp_cond = A64_COND_GE;
    605 			break;
    606 		case BPF_JSLE:
    607 			jmp_cond = A64_COND_LE;
    608 			break;
    609 		default:
    610 			return -EFAULT;
    611 		}
    612 		emit(A64_B_(jmp_cond, jmp_offset), ctx);
    613 		break;
    614 	case BPF_JMP | BPF_JSET | BPF_X:
    615 	case BPF_JMP32 | BPF_JSET | BPF_X:
    616 		emit(A64_TST(is64, dst, src), ctx);
    617 		goto emit_cond_jmp;
    618 	/* IF (dst COND imm) JUMP off */
    619 	case BPF_JMP | BPF_JEQ | BPF_K:
    620 	case BPF_JMP | BPF_JGT | BPF_K:
    621 	case BPF_JMP | BPF_JLT | BPF_K:
    622 	case BPF_JMP | BPF_JGE | BPF_K:
    623 	case BPF_JMP | BPF_JLE | BPF_K:
    624 	case BPF_JMP | BPF_JNE | BPF_K:
    625 	case BPF_JMP | BPF_JSGT | BPF_K:
    626 	case BPF_JMP | BPF_JSLT | BPF_K:
    627 	case BPF_JMP | BPF_JSGE | BPF_K:
    628 	case BPF_JMP | BPF_JSLE | BPF_K:
    629 	case BPF_JMP32 | BPF_JEQ | BPF_K:
    630 	case BPF_JMP32 | BPF_JGT | BPF_K:
    631 	case BPF_JMP32 | BPF_JLT | BPF_K:
    632 	case BPF_JMP32 | BPF_JGE | BPF_K:
    633 	case BPF_JMP32 | BPF_JLE | BPF_K:
    634 	case BPF_JMP32 | BPF_JNE | BPF_K:
    635 	case BPF_JMP32 | BPF_JSGT | BPF_K:
    636 	case BPF_JMP32 | BPF_JSLT | BPF_K:
    637 	case BPF_JMP32 | BPF_JSGE | BPF_K:
    638 	case BPF_JMP32 | BPF_JSLE | BPF_K:
    639 		emit_a64_mov_i(is64, tmp, imm, ctx);
    640 		emit(A64_CMP(is64, dst, tmp), ctx);
    641 		goto emit_cond_jmp;
    642 	case BPF_JMP | BPF_JSET | BPF_K:
    643 	case BPF_JMP32 | BPF_JSET | BPF_K:
    644 		emit_a64_mov_i(is64, tmp, imm, ctx);
    645 		emit(A64_TST(is64, dst, tmp), ctx);
    646 		goto emit_cond_jmp;
    647 	/* function call */
    648 	case BPF_JMP | BPF_CALL:
    649 	{
    650 		const u8 r0 = bpf2a64[BPF_REG_0];
    651 		bool func_addr_fixed;
    652 		u64 func_addr;
    653 		int ret;
    654 
    655 		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
    656 					    &func_addr, &func_addr_fixed);
    657 		if (ret < 0)
    658 			return ret;
    659 		emit_addr_mov_i64(tmp, func_addr, ctx);
    660 		emit(A64_BLR(tmp), ctx);
    661 		emit(A64_MOV(1, r0, A64_R(0)), ctx);
    662 		break;
    663 	}
    664 	/* tail call */
    665 	case BPF_JMP | BPF_TAIL_CALL:
    666 		if (emit_bpf_tail_call(ctx))
    667 			return -EFAULT;
    668 		break;
    669 	/* function return */
    670 	case BPF_JMP | BPF_EXIT:
    671 		/* Optimization: when the last instruction is EXIT,
    672 		   simply fall through to the epilogue. */
    673 		if (i == ctx->prog->len - 1)
    674 			break;
    675 		jmp_offset = epilogue_offset(ctx);
    676 		check_imm26(jmp_offset);
    677 		emit(A64_B(jmp_offset), ctx);
    678 		break;
    679 
    680 	/* dst = imm64 */
    681 	case BPF_LD | BPF_IMM | BPF_DW:
    682 	{
    683 		const struct bpf_insn insn1 = insn[1];
    684 		u64 imm64;
    685 
    686 		imm64 = (u64)insn1.imm << 32 | (u32)imm;
    687 		emit_a64_mov_i64(dst, imm64, ctx);
    688 
    689 		return 1;
    690 	}
    691 
    692 	/* LDX: dst = *(size *)(src + off) */
    693 	case BPF_LDX | BPF_MEM | BPF_W:
    694 	case BPF_LDX | BPF_MEM | BPF_H:
    695 	case BPF_LDX | BPF_MEM | BPF_B:
    696 	case BPF_LDX | BPF_MEM | BPF_DW:
    697 		emit_a64_mov_i(1, tmp, off, ctx);
    698 		switch (BPF_SIZE(code)) {
    699 		case BPF_W:
    700 			emit(A64_LDR32(dst, src, tmp), ctx);
    701 			break;
    702 		case BPF_H:
    703 			emit(A64_LDRH(dst, src, tmp), ctx);
    704 			break;
    705 		case BPF_B:
    706 			emit(A64_LDRB(dst, src, tmp), ctx);
    707 			break;
    708 		case BPF_DW:
    709 			emit(A64_LDR64(dst, src, tmp), ctx);
    710 			break;
    711 		}
    712 		break;
    713 
    714 	/* ST: *(size *)(dst + off) = imm */
    715 	case BPF_ST | BPF_MEM | BPF_W:
    716 	case BPF_ST | BPF_MEM | BPF_H:
    717 	case BPF_ST | BPF_MEM | BPF_B:
    718 	case BPF_ST | BPF_MEM | BPF_DW:
    719 		/* Load imm to a register then store it */
    720 		emit_a64_mov_i(1, tmp2, off, ctx);
    721 		emit_a64_mov_i(1, tmp, imm, ctx);
    722 		switch (BPF_SIZE(code)) {
    723 		case BPF_W:
    724 			emit(A64_STR32(tmp, dst, tmp2), ctx);
    725 			break;
    726 		case BPF_H:
    727 			emit(A64_STRH(tmp, dst, tmp2), ctx);
    728 			break;
    729 		case BPF_B:
    730 			emit(A64_STRB(tmp, dst, tmp2), ctx);
    731 			break;
    732 		case BPF_DW:
    733 			emit(A64_STR64(tmp, dst, tmp2), ctx);
    734 			break;
    735 		}
    736 		break;
    737 
    738 	/* STX: *(size *)(dst + off) = src */
    739 	case BPF_STX | BPF_MEM | BPF_W:
    740 	case BPF_STX | BPF_MEM | BPF_H:
    741 	case BPF_STX | BPF_MEM | BPF_B:
    742 	case BPF_STX | BPF_MEM | BPF_DW:
    743 		emit_a64_mov_i(1, tmp, off, ctx);
    744 		switch (BPF_SIZE(code)) {
    745 		case BPF_W:
    746 			emit(A64_STR32(src, dst, tmp), ctx);
    747 			break;
    748 		case BPF_H:
    749 			emit(A64_STRH(src, dst, tmp), ctx);
    750 			break;
    751 		case BPF_B:
    752 			emit(A64_STRB(src, dst, tmp), ctx);
    753 			break;
    754 		case BPF_DW:
    755 			emit(A64_STR64(src, dst, tmp), ctx);
    756 			break;
    757 		}
    758 		break;
    759 	/* STX XADD: lock *(u32 *)(dst + off) += src */
    760 	case BPF_STX | BPF_XADD | BPF_W:
    761 	/* STX XADD: lock *(u64 *)(dst + off) += src */
    762 	case BPF_STX | BPF_XADD | BPF_DW:
    763 		emit_a64_mov_i(1, tmp, off, ctx);
    764 		emit(A64_ADD(1, tmp, tmp, dst), ctx);
    765 		emit(A64_PRFM(tmp, PST, L1, STRM), ctx);
    766 		emit(A64_LDXR(isdw, tmp2, tmp), ctx);
    767 		emit(A64_ADD(isdw, tmp2, tmp2, src), ctx);
    768 		emit(A64_STXR(isdw, tmp2, tmp, tmp3), ctx);
    769 		jmp_offset = -3;
    770 		check_imm19(jmp_offset);
    771 		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
    772 		break;
    773 
    774 	default:
    775 		pr_err_once("unknown opcode %02x\n", code);
    776 		return -EINVAL;
    777 	}
    778 
    779 	return 0;
    780 }
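
One case above that is easy to misread is STX XADD: the emitted LDXR/ADD/STXR/CBNZ sequence is a load-exclusive/store-exclusive retry loop, i.e. essentially an atomic add with relaxed ordering. A hedged standalone sketch of the equivalent operation using GCC/Clang atomic builtins (the JIT emits the exclusive-monitor loop directly, it does not call anything):

    #include <stdint.h>
    #include <stdio.h>

    /* Equivalent of the emitted LDXR/ADD/STXR/CBNZ loop for BPF_W:
     * retry the read-modify-write until the exclusive store succeeds. */
    static void xadd32(uint32_t *addr, uint32_t val)
    {
        uint32_t old = __atomic_load_n(addr, __ATOMIC_RELAXED);

        while (!__atomic_compare_exchange_n(addr, &old, old + val,
                                            1 /* weak, like LL/SC */,
                                            __ATOMIC_RELAXED, __ATOMIC_RELAXED))
            ;   /* 'old' was refreshed by the failed exchange; try again */
    }

    int main(void)
    {
        uint32_t counter = 40;

        xadd32(&counter, 2);
        printf("%u\n", counter);   /* 42 */
        return 0;
    }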
    781 
    782 static int build_body(struct jit_ctx *ctx, bool extra_pass)
    783 {
    784 	const struct bpf_prog *prog = ctx->prog;
    785 	int i;
    786 
    787 	for (i = 0; i < prog->len; i++) {
    788 		const struct bpf_insn *insn = &prog->insnsi[i];
    789 		int ret;
    790 
    791 		ret = build_insn(insn, ctx, extra_pass);
    792 		if (ret > 0) {
    793 			i++;
    794 			if (ctx->image == NULL)
    795 				ctx->offset[i] = ctx->idx;
    796 			continue;
    797 		}
    798 		if (ctx->image == NULL)
    799 			ctx->offset[i] = ctx->idx;
    800 		if (ret)
    801 			return ret;
    802 	}
    803 
    804 	return 0;
    805 }
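
Together with emit() near the top of the file, this is the classic two-pass JIT: on the first call ctx->image is NULL, so emit() only advances ctx->idx (sizing) while ctx->offset[] records where each BPF instruction's code ends; once the buffer is allocated, the same code runs again and actually writes instructions. A minimal hedged sketch of the pattern (standalone, toy encodings, not part of this file):

    #include <stdint.h>
    #include <stdio.h>
    #include <stdlib.h>

    struct toy_ctx {
        uint32_t *image;   /* NULL on the sizing pass */
        int idx;
    };

    /* Counts on the first pass, writes on the second -- like emit(). */
    static void toy_emit(uint32_t insn, struct toy_ctx *ctx)
    {
        if (ctx->image != NULL)
            ctx->image[ctx->idx] = insn;
        ctx->idx++;
    }

    static void toy_build(struct toy_ctx *ctx)
    {
        toy_emit(0xd503201f /* AArch64 NOP */, ctx);
        toy_emit(0xd65f03c0 /* AArch64 RET */, ctx);
    }

    int main(void)
    {
        struct toy_ctx ctx = { .image = NULL, .idx = 0 };

        toy_build(&ctx);                           /* pass 1: size only   */
        ctx.image = calloc(ctx.idx, sizeof(uint32_t));
        if (!ctx.image)
            return 1;
        ctx.idx = 0;
        toy_build(&ctx);                           /* pass 2: really emit */

        printf("emitted %d instructions\n", ctx.idx);
        free(ctx.image);
        return 0;
    }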
    806 
    807 static int validate_code(struct jit_ctx *ctx)
    808 {
    809 	int i;
    810 
    811 	for (i = 0; i < ctx->idx; i++) {
    812 		u32 a64_insn = le32_to_cpu(ctx->image[i]);
    813 
    814 		if (a64_insn == AARCH64_BREAK_FAULT)
    815 			return -1;
    816 	}
    817 
    818 	return 0;
    819 }
    820 
    821 static inline void bpf_flush_icache(void *start, void *end)
    822 {
    823 	flush_icache_range((unsigned long)start, (unsigned long)end);
    824 }
    825 
    826 struct arm64_jit_data {
    827 	struct bpf_binary_header *header;
    828 	u8 *image;
    829 	struct jit_ctx ctx;
    830 };
    831 
    832 struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
    833 {
    834 	struct bpf_prog *tmp, *orig_prog = prog;
    835 	struct bpf_binary_header *header;
    836 	struct arm64_jit_data *jit_data;
    837 	bool was_classic = bpf_prog_was_classic(prog);
    838 	bool tmp_blinded = false;
    839 	bool extra_pass = false;
    840 	struct jit_ctx ctx;
    841 	int image_size;
    842 	u8 *image_ptr;
    843 
    844 	if (!prog->jit_requested)
    845 		return orig_prog;
    846 
    847 	tmp = bpf_jit_blind_constants(prog);
    848 	/* If blinding was requested and we failed during blinding,
    849 	 * we must fall back to the interpreter.
    850 	 */
    851 	if (IS_ERR(tmp))
    852 		return orig_prog;
    853 	if (tmp != prog) {
    854 		tmp_blinded = true;
    855 		prog = tmp;
    856 	}
    857 
    858 	jit_data = prog->aux->jit_data;
    859 	if (!jit_data) {
    860 		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
    861 		if (!jit_data) {
    862 			prog = orig_prog;
    863 			goto out;
    864 		}
    865 		prog->aux->jit_data = jit_data;
    866 	}
    867 	if (jit_data->ctx.offset) {
    868 		ctx = jit_data->ctx;
    869 		image_ptr = jit_data->image;
    870 		header = jit_data->header;
    871 		extra_pass = true;
    872 		image_size = sizeof(u32) * ctx.idx;
    873 		goto skip_init_ctx;
    874 	}
    875 	memset(&ctx, 0, sizeof(ctx));
    876 	ctx.prog = prog;
    877 
    878 	ctx.offset = kcalloc(prog->len, sizeof(int), GFP_KERNEL);
    879 	if (ctx.offset == NULL) {
    880 		prog = orig_prog;
    881 		goto out_off;
    882 	}
    883 
    884 	/* 1. Initial fake pass to compute ctx->idx. */
    885 
    886 	/* Fake pass to fill in ctx->offset. */
    887 	if (build_body(&ctx, extra_pass)) {
    888 		prog = orig_prog;
    889 		goto out_off;
    890 	}
    891 
    892 	if (build_prologue(&ctx, was_classic)) {
    893 		prog = orig_prog;
    894 		goto out_off;
    895 	}
    896 
    897 	ctx.epilogue_offset = ctx.idx;
    898 	build_epilogue(&ctx);
    899 
    900 	/* Now we know the actual image size. */
    901 	image_size = sizeof(u32) * ctx.idx;
    902 	header = bpf_jit_binary_alloc(image_size, &image_ptr,
    903 				      sizeof(u32), jit_fill_hole);
    904 	if (header == NULL) {
    905 		prog = orig_prog;
    906 		goto out_off;
    907 	}
    908 
    909 	/* 2. Now, the actual pass. */
    910 
    911 	ctx.image = (__le32 *)image_ptr;
    912 skip_init_ctx:
    913 	ctx.idx = 0;
    914 
    915 	build_prologue(&ctx, was_classic);
    916 
    917 	if (build_body(&ctx, extra_pass)) {
    918 		bpf_jit_binary_free(header);
    919 		prog = orig_prog;
    920 		goto out_off;
    921 	}
    922 
    923 	build_epilogue(&ctx);
    924 
    925 	/* 3. Extra pass to validate JITed code. */
    926 	if (validate_code(&ctx)) {
    927 		bpf_jit_binary_free(header);
    928 		prog = orig_prog;
    929 		goto out_off;
    930 	}
    931 
    932 	/* And we're done. */
    933 	if (bpf_jit_enable > 1)
    934 		bpf_jit_dump(prog->len, image_size, 2, ctx.image);
    935 
    936 	bpf_flush_icache(header, ctx.image + ctx.idx);
    937 
    938 	if (!prog->is_func || extra_pass) {
    939 		if (extra_pass && ctx.idx != jit_data->ctx.idx) {
    940 			pr_err_once("multi-func JIT bug %d != %d\n",
    941 				    ctx.idx, jit_data->ctx.idx);
    942 			bpf_jit_binary_free(header);
    943 			prog->bpf_func = NULL;
    944 			prog->jited = 0;
    945 			goto out_off;
    946 		}
    947 		bpf_jit_binary_lock_ro(header);
    948 	} else {
    949 		jit_data->ctx = ctx;
    950 		jit_data->image = image_ptr;
    951 		jit_data->header = header;
    952 	}
    953 	prog->bpf_func = (void *)ctx.image;
    954 	prog->jited = 1;
    955 	prog->jited_len = image_size;
    956 
    957 	if (!prog->is_func || extra_pass) {
    958 		bpf_prog_fill_jited_linfo(prog, ctx.offset);
    959 out_off:
    960 		kfree(ctx.offset);
    961 		kfree(jit_data);
    962 		prog->aux->jit_data = NULL;
    963 	}
    964 out:
    965 	if (tmp_blinded)
    966 		bpf_jit_prog_release_other(prog, prog == orig_prog ?
    967 					   tmp : orig_prog);
    968 	return prog;
    969 }
    970 
    971 void *bpf_jit_alloc_exec(unsigned long size)
    972 {
    973 	return __vmalloc_node_range(size, PAGE_SIZE, BPF_JIT_REGION_START,
    974 				    BPF_JIT_REGION_END, GFP_KERNEL,
    975 				    PAGE_KERNEL_EXEC, 0, NUMA_NO_NODE,
    976 				    __builtin_return_address(0));
    977 }
    978 
    979 void bpf_jit_free_exec(void *addr)
    980 {
    981 	return vfree(addr);
    982 }